22//!
33//! Implements efficient batch addition of affine elliptic curve points
44//! using Montgomery's batch inversion trick to minimize field inversions.
5-
5+ //! @TODO(markosg04) duplicate group elements?
66use ark_bn254:: G1Affine ;
77use ark_ec:: AffineRepr ;
88use rayon:: prelude:: * ;
@@ -63,7 +63,7 @@ pub fn batch_g1_additions(bases: &[G1Affine], indices: &[usize]) -> G1Affine {
6363 let lambda = ( p2. y - p1. y ) * inv;
6464 let x3 = lambda * lambda - p1. x - p2. x ;
6565 let y3 = lambda * ( p1. x - x3) - p1. y ;
66- G1Affine :: new ( x3, y3)
66+ G1Affine :: new_unchecked ( x3, y3)
6767 } )
6868 . collect ( ) ;
6969
@@ -78,6 +78,95 @@ pub fn batch_g1_additions(bases: &[G1Affine], indices: &[usize]) -> G1Affine {
7878 points[ 0 ]
7979}
8080
81+ /// Performs multiple batch additions of G1 affine points in parallel.
82+ ///
83+ /// Given a slice of base points and multiple sets of indices, computes the sum
84+ /// for each set of indices. All additions across all batches share the same
85+ /// batch inversion.
86+ ///
87+ /// # Arguments
88+ /// * `bases` - Slice of G1 affine points to select from
89+ /// * `indices_sets` - Vector of index vectors, each specifying which points to sum
90+ ///
91+ /// # Returns
92+ /// Vector of sums, one for each index set
93+ pub fn batch_g1_additions_multi ( bases : & [ G1Affine ] , indices_sets : & [ Vec < usize > ] ) -> Vec < G1Affine > {
94+ if indices_sets. is_empty ( ) {
95+ return vec ! [ ] ;
96+ }
97+
98+ // Initialize working sets for each batch
99+ let mut working_sets: Vec < Vec < G1Affine > > = indices_sets
100+ . par_iter ( )
101+ . map ( |indices| {
102+ if indices. is_empty ( ) {
103+ vec ! [ G1Affine :: zero( ) ]
104+ } else if indices. len ( ) == 1 {
105+ vec ! [ bases[ indices[ 0 ] ] ]
106+ } else {
107+ indices. iter ( ) . map ( |& i| bases[ i] ) . collect ( )
108+ }
109+ } )
110+ . collect ( ) ;
111+
112+ // Continue until all sets have been reduced to a single point
113+ loop {
114+ // Count total number of pairs across all sets
115+ let total_pairs: usize = working_sets. iter ( ) . map ( |set| set. len ( ) / 2 ) . sum ( ) ;
116+
117+ if total_pairs == 0 {
118+ break ;
119+ }
120+
121+ // Collect all denominators across all sets
122+ let mut all_denominators = Vec :: with_capacity ( total_pairs) ;
123+ let mut pair_info = Vec :: with_capacity ( total_pairs) ;
124+
125+ for ( set_idx, set) in working_sets. iter ( ) . enumerate ( ) {
126+ let pairs_in_set = set. len ( ) / 2 ;
127+ for pair_idx in 0 ..pairs_in_set {
128+ let p1 = set[ pair_idx * 2 ] ;
129+ let p2 = set[ pair_idx * 2 + 1 ] ;
130+ all_denominators. push ( p2. x - p1. x ) ;
131+ pair_info. push ( ( set_idx, pair_idx) ) ;
132+ }
133+ }
134+
135+ // Batch invert all denominators at once
136+ let mut inverses = all_denominators;
137+ ark_ff:: fields:: batch_inversion ( & mut inverses) ;
138+
139+ // Apply additions using the inverted denominators
140+ let mut new_working_sets: Vec < Vec < G1Affine > > = working_sets
141+ . iter ( )
142+ . map ( |set| Vec :: with_capacity ( ( set. len ( ) + 1 ) / 2 ) )
143+ . collect ( ) ;
144+
145+ // Process additions and maintain order
146+ for ( ( set_idx, pair_idx) , inv) in pair_info. iter ( ) . zip ( inverses. iter ( ) ) {
147+ let set = & working_sets[ * set_idx] ;
148+ let p1 = set[ * pair_idx * 2 ] ;
149+ let p2 = set[ * pair_idx * 2 + 1 ] ;
150+ let lambda = ( p2. y - p1. y ) * inv;
151+ let x3 = lambda * lambda - p1. x - p2. x ;
152+ let y3 = lambda * ( p1. x - x3) - p1. y ;
153+ new_working_sets[ * set_idx] . push ( G1Affine :: new_unchecked ( x3, y3) ) ;
154+ }
155+
156+ // Handle odd elements
157+ for ( set_idx, set) in working_sets. iter ( ) . enumerate ( ) {
158+ if set. len ( ) % 2 == 1 {
159+ new_working_sets[ set_idx] . push ( set[ set. len ( ) - 1 ] ) ;
160+ }
161+ }
162+
163+ working_sets = new_working_sets;
164+ }
165+
166+ // Extract final results
167+ working_sets. into_iter ( ) . map ( |set| set[ 0 ] ) . collect ( )
168+ }
169+
81170#[ cfg( test) ]
82171mod tests {
83172 use super :: * ;
@@ -123,8 +212,8 @@ mod tests {
123212 fn test_stress_test_correctness ( ) {
124213 let mut rng = ark_std:: test_rng ( ) ;
125214
126- let base_size = 10000 ;
127- let indices_size = 5000 ;
215+ let base_size = 100000 ;
216+ let indices_size = 50000 ;
128217
129218 let bases: Vec < G1Affine > = ( 0 ..base_size) . map ( |_| G1Affine :: rand ( & mut rng) ) . collect ( ) ;
130219
@@ -134,7 +223,6 @@ mod tests {
134223
135224 let batch_result = batch_g1_additions ( & bases, & indices) ;
136225
137- // Compute expected result using naive sequential addition
138226 let mut expected = G1Affine :: zero ( ) ;
139227 for & idx in & indices {
140228 expected = ( expected + bases[ idx] ) . into_affine ( ) ;
@@ -145,4 +233,34 @@ mod tests {
145233 "Stress test failed: batch result doesn't match expected sum"
146234 ) ;
147235 }
236+
237+ #[ test]
238+ fn test_batch_additions_multi_large ( ) {
239+ let mut rng = ark_std:: test_rng ( ) ;
240+
241+ let base_size = 10000 ;
242+ let num_batches = 50 ;
243+
244+ let bases: Vec < G1Affine > = ( 0 ..base_size) . map ( |_| G1Affine :: rand ( & mut rng) ) . collect ( ) ;
245+
246+ let indices_sets: Vec < Vec < usize > > = ( 0 ..num_batches)
247+ . map ( |_| {
248+ let size = ( rng. next_u64 ( ) as usize ) % 100 + 1 ;
249+ ( 0 ..size)
250+ . map ( |_| ( rng. next_u64 ( ) as usize ) % base_size)
251+ . collect ( )
252+ } )
253+ . collect ( ) ;
254+
255+ let batch_results = batch_g1_additions_multi ( & bases, & indices_sets) ;
256+
257+ for ( i, ( result, indices) ) in batch_results. iter ( ) . zip ( indices_sets. iter ( ) ) . enumerate ( ) {
258+ let single_result = batch_g1_additions ( & bases, indices) ;
259+ assert_eq ! (
260+ * result, single_result,
261+ "Multi vs single mismatch at batch {}" ,
262+ i
263+ ) ;
264+ }
265+ }
148266}
0 commit comments