@@ -62,6 +62,12 @@ pub trait VariableBaseMSM: ScalarMul + for<'a> AddAssign<&'a Self::Bucket> {
6262 . collect :: < Vec < _ > > ( ) ;
6363 Self :: msm_bigint ( bases, bigints. as_slice ( ) )
6464 }
65+ fn msm_unchecked_serial ( bases : & [ Self :: MulBase ] , scalars : & [ Self :: ScalarField ] ) -> Self {
66+ let bigints = cfg_into_iter ! ( scalars)
67+ . map ( |s| s. into_bigint ( ) )
68+ . collect :: < Vec < _ > > ( ) ;
69+ Self :: msm_bigint_serial ( bases, bigints. as_slice ( ) )
70+ }
6571
6672 /// Performs multi-scalar multiplication.
6773 ///
@@ -75,47 +81,62 @@ pub trait VariableBaseMSM: ScalarMul + for<'a> AddAssign<&'a Self::Bucket> {
7581 . then ( || Self :: msm_unchecked ( bases, scalars) )
7682 . ok_or_else ( || bases. len ( ) . min ( scalars. len ( ) ) )
7783 }
84+ fn msm_serial ( bases : & [ Self :: MulBase ] , scalars : & [ Self :: ScalarField ] ) -> Result < Self , usize > {
85+ ( bases. len ( ) == scalars. len ( ) )
86+ . then ( || Self :: msm_unchecked_serial ( bases, scalars) )
87+ . ok_or_else ( || bases. len ( ) . min ( scalars. len ( ) ) )
88+ }
7889
7990 /// Optimized implementation of multi-scalar multiplication.
8091 fn msm_bigint (
8192 bases : & [ Self :: MulBase ] ,
8293 bigints : & [ <Self :: ScalarField as PrimeField >:: BigInt ] ,
8394 ) -> Self {
8495 if Self :: NEGATION_IS_CHEAP {
85- msm_signed ( bases, bigints)
96+ msm_signed ( bases, bigints, false )
97+ } else {
98+ msm_unsigned ( bases, bigints, false )
99+ }
100+ }
101+ fn msm_bigint_serial (
102+ bases : & [ Self :: MulBase ] ,
103+ bigints : & [ <Self :: ScalarField as PrimeField >:: BigInt ] ,
104+ ) -> Self {
105+ if Self :: NEGATION_IS_CHEAP {
106+ msm_signed ( bases, bigints, true )
86107 } else {
87- msm_unsigned ( bases, bigints)
108+ msm_unsigned ( bases, bigints, true )
88109 }
89110 }
90111
91112 /// Performs multi-scalar multiplication when the scalars are known to be boolean.
92113 /// The default implementation is faster than [`Self::msm_bigint`].
93- fn msm_u1 ( bases : & [ Self :: MulBase ] , scalars : & [ bool ] ) -> Self {
94- msm_binary ( bases, scalars)
114+ fn msm_u1 ( bases : & [ Self :: MulBase ] , scalars : & [ bool ] , serial : bool ) -> Self {
115+ msm_binary ( bases, scalars, serial )
95116 }
96117
97118 /// Performs multi-scalar multiplication when the scalars are known to be `u8`-sized.
98119 /// The default implementation is faster than [`Self::msm_bigint`].
99- fn msm_u8 ( bases : & [ Self :: MulBase ] , scalars : & [ u8 ] ) -> Self {
100- msm_u8 ( bases, scalars)
120+ fn msm_u8 ( bases : & [ Self :: MulBase ] , scalars : & [ u8 ] , serial : bool ) -> Self {
121+ msm_u8 ( bases, scalars, serial )
101122 }
102123
103124 /// Performs multi-scalar multiplication when the scalars are known to be `u16`-sized.
104125 /// The default implementation is faster than [`Self::msm_bigint`].
105- fn msm_u16 ( bases : & [ Self :: MulBase ] , scalars : & [ u16 ] ) -> Self {
106- msm_u16 ( bases, scalars)
126+ fn msm_u16 ( bases : & [ Self :: MulBase ] , scalars : & [ u16 ] , serial : bool ) -> Self {
127+ msm_u16 ( bases, scalars, serial )
107128 }
108129
109130 /// Performs multi-scalar multiplication when the scalars are known to be `u32`-sized.
110131 /// The default implementation is faster than [`Self::msm_bigint`].
111- fn msm_u32 ( bases : & [ Self :: MulBase ] , scalars : & [ u32 ] ) -> Self {
112- msm_u32 ( bases, scalars)
132+ fn msm_u32 ( bases : & [ Self :: MulBase ] , scalars : & [ u32 ] , serial : bool ) -> Self {
133+ msm_u32 ( bases, scalars, serial )
113134 }
114135
115136 /// Performs multi-scalar multiplication when the scalars are known to be `u64`-sized.
116137 /// The default implementation is faster than [`Self::msm_bigint`].
117- fn msm_u64 ( bases : & [ Self :: MulBase ] , scalars : & [ u64 ] ) -> Self {
118- msm_u64 ( bases, scalars)
138+ fn msm_u64 ( bases : & [ Self :: MulBase ] , scalars : & [ u64 ] , serial : bool ) -> Self {
139+ msm_u64 ( bases, scalars, serial )
119140 }
120141
121142 /// Streaming multi-scalar multiplication algorithm with hard-coded chunk
@@ -188,6 +209,7 @@ fn iget_group<A: Send + Sync, B: Send + Sync>(
188209fn msm_unsigned < V : VariableBaseMSM > (
189210 bases : & [ V :: MulBase ] ,
190211 scalars : & [ <V :: ScalarField as PrimeField >:: BigInt ] ,
212+ serial : bool ,
191213) -> V {
192214 // Partition scalars according to whether
193215 // 1. they are in the range {0, 1};
@@ -253,12 +275,12 @@ fn msm_unsigned<V: VariableBaseMSM>(
253275 } ) ;
254276 let ( b7, s7) = uget_group ( & grouped[ s7..] , |i| ( bases[ i] , scalars[ i] ) ) ;
255277
256- let result: V = msm_binary :: < V > ( & b1, & s1)
257- + msm_u8 :: < V > ( & b3, & s3)
258- + msm_u16 :: < V > ( & b4, & s4)
259- + msm_u32 :: < V > ( & b5, & s5)
260- + msm_u64 :: < V > ( & b6, & s6)
261- + msm_bigint :: < V > ( & b7, & s7, V :: ScalarField :: MODULUS_BIT_SIZE as usize ) ;
278+ let result: V = msm_binary :: < V > ( & b1, & s1, serial )
279+ + msm_u8 :: < V > ( & b3, & s3, serial )
280+ + msm_u16 :: < V > ( & b4, & s4, serial )
281+ + msm_u32 :: < V > ( & b5, & s5, serial )
282+ + msm_u64 :: < V > ( & b6, & s6, serial )
283+ + msm_bigint :: < V > ( & b7, & s7, V :: ScalarField :: MODULUS_BIT_SIZE as usize , serial ) ;
262284 result. into ( )
263285}
264286
@@ -276,6 +298,7 @@ fn sub<B: BigInteger>(m: &B, scalar: &B) -> u64 {
276298fn msm_signed < V : VariableBaseMSM > (
277299 bases : & [ V :: MulBase ] ,
278300 scalars : & [ <V :: ScalarField as PrimeField >:: BigInt ] ,
301+ serial : bool ,
279302) -> V {
280303 // Partition scalars according to whether
281304 // 1. they are in the range {-1, 0, 1};
@@ -358,7 +381,7 @@ fn msm_signed<V: VariableBaseMSM>(
358381 let ( ib, is) = iget_group ( & grouped[ si1..su8] , |i| {
359382 ( bases[ i] , sub ( & m, & scalars[ i] ) == 1 )
360383 } ) ;
361- result = msm_binary :: < V > ( & ub, & us) - msm_binary :: < V > ( & ib, & is) ;
384+ result = msm_binary :: < V > ( & ub, & us, serial ) - msm_binary :: < V > ( & ib, & is, serial ) ;
362385
363386 // Handle positive and negative u8 scalars.
364387 let ( ub, us) = iget_group ( & grouped[ su8..si8] , |i| {
@@ -367,7 +390,7 @@ fn msm_signed<V: VariableBaseMSM>(
367390 let ( ib, is) = iget_group ( & grouped[ si8..su16] , |i| {
368391 ( bases[ i] , sub ( & m, & scalars[ i] ) as u8 )
369392 } ) ;
370- result += msm_u8 :: < V > ( & ub, & us) - msm_u8 :: < V > ( & ib, & is) ;
393+ result += msm_u8 :: < V > ( & ub, & us, serial ) - msm_u8 :: < V > ( & ib, & is, serial ) ;
371394
372395 // Handle positive and negative u16 scalars.
373396 let ( ub, us) = iget_group ( & grouped[ su16..si16] , |i| {
@@ -376,7 +399,7 @@ fn msm_signed<V: VariableBaseMSM>(
376399 let ( ib, is) = iget_group ( & grouped[ si16..su32] , |i| {
377400 ( bases[ i] , sub ( & m, & scalars[ i] ) as u16 )
378401 } ) ;
379- result += msm_u16 :: < V > ( & ub, & us) - msm_u16 :: < V > ( & ib, & is) ;
402+ result += msm_u16 :: < V > ( & ub, & us, serial ) - msm_u16 :: < V > ( & ib, & is, serial ) ;
380403
381404 // Handle positive and negative u32 scalars.
382405 let ( ub, us) = iget_group ( & grouped[ su32..si32] , |i| {
@@ -385,12 +408,12 @@ fn msm_signed<V: VariableBaseMSM>(
385408 let ( ib, is) = iget_group ( & grouped[ si32..su64] , |i| {
386409 ( bases[ i] , sub ( & m, & scalars[ i] ) as u32 )
387410 } ) ;
388- result += msm_u32 :: < V > ( & ub, & us) - msm_u32 :: < V > ( & ib, & is) ;
411+ result += msm_u32 :: < V > ( & ub, & us, serial ) - msm_u32 :: < V > ( & ib, & is, serial ) ;
389412
390413 // Handle positive and negative u64 scalars.
391414 let ( ub, us) = iget_group ( & grouped[ su64..si64] , |i| ( bases[ i] , scalars[ i] . as_ref ( ) [ 0 ] ) ) ;
392415 let ( ib, is) = iget_group ( & grouped[ si64..sf] , |i| ( bases[ i] , sub ( & m, & scalars[ i] ) ) ) ;
393- result += msm_u64 :: < V > ( & ub, & us) - msm_u64 :: < V > ( & ib, & is) ;
416+ result += msm_u64 :: < V > ( & ub, & us, serial ) - msm_u64 :: < V > ( & ib, & is, serial ) ;
394417
395418 // Handle the rest of the scalars.
396419 let ( bf, sf) = iget_group ( & grouped[ sf..] , |i| ( bases[ i] , scalars[ i] ) ) ;
@@ -399,15 +422,19 @@ fn msm_signed<V: VariableBaseMSM>(
399422 result. into ( )
400423}
401424
402- fn preamble < A , B > ( bases : & mut & [ A ] , scalars : & mut & [ B ] ) -> Option < usize > {
425+ fn preamble < A , B > ( bases : & mut & [ A ] , scalars : & mut & [ B ] , _serial : bool ) -> Option < usize > {
403426 let size = bases. len ( ) . min ( scalars. len ( ) ) ;
404427 if size == 0 {
405428 return None ;
406429 }
407430 #[ cfg( feature = "parallel" ) ]
408431 let chunk_size = {
409432 let chunk_size = size / rayon:: current_num_threads ( ) ;
410- if chunk_size == 0 { size } else { chunk_size }
433+ if _serial || chunk_size == 0 {
434+ size
435+ } else {
436+ chunk_size
437+ }
411438 } ;
412439 #[ cfg( not( feature = "parallel" ) ) ]
413440 let chunk_size = size;
@@ -419,8 +446,12 @@ fn preamble<A, B>(bases: &mut &[A], scalars: &mut &[B]) -> Option<usize> {
419446
420447/// Computes multi-scalar multiplication where the scalars
421448/// lie in the range {-1, 0, 1}.
422- pub fn msm_binary < V : VariableBaseMSM > ( mut bases : & [ V :: MulBase ] , mut scalars : & [ bool ] ) -> V {
423- let chunk_size = match preamble ( & mut bases, & mut scalars) {
449+ pub fn msm_binary < V : VariableBaseMSM > (
450+ mut bases : & [ V :: MulBase ] ,
451+ mut scalars : & [ bool ] ,
452+ serial : bool ,
453+ ) -> V {
454+ let chunk_size = match preamble ( & mut bases, & mut scalars, serial) {
424455 Some ( chunk_size) => chunk_size,
425456 None => return V :: zero ( ) ,
426457 } ;
@@ -438,47 +469,62 @@ pub fn msm_binary<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[b
438469 . sum ( )
439470}
440471
441- pub fn msm_u8 < V : VariableBaseMSM > ( mut bases : & [ V :: MulBase ] , mut scalars : & [ u8 ] ) -> V {
442- let chunk_size = match preamble ( & mut bases, & mut scalars) {
472+ pub fn msm_u8 < V : VariableBaseMSM > ( mut bases : & [ V :: MulBase ] , mut scalars : & [ u8 ] , serial : bool ) -> V {
473+ let chunk_size = match preamble ( & mut bases, & mut scalars, serial ) {
443474 Some ( chunk_size) => chunk_size,
444475 None => return V :: zero ( ) ,
445476 } ;
446477 cfg_chunks ! ( bases, chunk_size)
447478 . zip ( cfg_chunks ! ( scalars, chunk_size) )
448- . map ( |( bases, scalars) | msm_serial :: < V > ( bases, scalars) )
479+ . map ( |( bases, scalars) | msm_serial :: < V , _ > ( bases, scalars) )
449480 . sum ( )
450481}
451482
452- pub fn msm_u16 < V : VariableBaseMSM > ( mut bases : & [ V :: MulBase ] , mut scalars : & [ u16 ] ) -> V {
453- let chunk_size = match preamble ( & mut bases, & mut scalars) {
483+ pub fn msm_u16 < V : VariableBaseMSM > (
484+ mut bases : & [ V :: MulBase ] ,
485+ mut scalars : & [ u16 ] ,
486+ serial : bool ,
487+ ) -> V {
488+ let chunk_size = match preamble ( & mut bases, & mut scalars, serial) {
454489 Some ( chunk_size) => chunk_size,
455490 None => return V :: zero ( ) ,
456491 } ;
457492 cfg_chunks ! ( bases, chunk_size)
458493 . zip ( cfg_chunks ! ( scalars, chunk_size) )
459- . map ( |( bases, scalars) | msm_serial :: < V > ( bases, scalars) )
494+ . map ( |( bases, scalars) | msm_serial :: < V , _ > ( bases, scalars) )
460495 . sum ( )
461496}
462497
463- pub fn msm_u32 < V : VariableBaseMSM > ( mut bases : & [ V :: MulBase ] , mut scalars : & [ u32 ] ) -> V {
464- let chunk_size = match preamble ( & mut bases, & mut scalars) {
498+ pub fn msm_u32 < V : VariableBaseMSM > (
499+ mut bases : & [ V :: MulBase ] ,
500+ mut scalars : & [ u32 ] ,
501+ serial : bool ,
502+ ) -> V {
503+ let chunk_size = match preamble ( & mut bases, & mut scalars, serial) {
465504 Some ( chunk_size) => chunk_size,
466505 None => return V :: zero ( ) ,
467506 } ;
468507 cfg_chunks ! ( bases, chunk_size)
469508 . zip ( cfg_chunks ! ( scalars, chunk_size) )
470- . map ( |( bases, scalars) | msm_serial :: < V > ( bases, scalars) )
509+ . map ( |( bases, scalars) | msm_serial :: < V , _ > ( bases, scalars) )
471510 . sum ( )
472511}
473512
474- pub fn msm_u64 < V : VariableBaseMSM > ( mut bases : & [ V :: MulBase ] , mut scalars : & [ u64 ] ) -> V {
475- let chunk_size = match preamble ( & mut bases, & mut scalars) {
513+ pub fn msm_u64 < V : VariableBaseMSM > (
514+ mut bases : & [ V :: MulBase ] ,
515+ mut scalars : & [ u64 ] ,
516+ serial : bool ,
517+ ) -> V {
518+ if serial {
519+ return msm_serial :: < V , _ > ( bases, scalars) ;
520+ }
521+ let chunk_size = match preamble ( & mut bases, & mut scalars, serial) {
476522 Some ( chunk_size) => chunk_size,
477523 None => return V :: zero ( ) ,
478524 } ;
479525 cfg_chunks ! ( bases, chunk_size)
480526 . zip ( cfg_chunks ! ( scalars, chunk_size) )
481- . map ( |( bases, scalars) | msm_serial :: < V > ( bases, scalars) )
527+ . map ( |( bases, scalars) | msm_serial :: < V , _ > ( bases, scalars) )
482528 . sum ( )
483529}
484530
@@ -577,7 +623,11 @@ fn msm_bigint_wnaf<V: VariableBaseMSM>(
577623 cur_num_threads / THREADS_PER_CHUNK
578624 } ;
579625 let chunk_size = size / num_chunks;
580- if chunk_size == 0 { size } else { chunk_size }
626+ if chunk_size == 0 {
627+ size
628+ } else {
629+ chunk_size
630+ }
581631 } ;
582632 #[ cfg( not( feature = "parallel" ) ) ]
583633 let chunk_size = size;
@@ -608,8 +658,9 @@ fn msm_bigint<V: VariableBaseMSM>(
608658 mut bases : & [ V :: MulBase ] ,
609659 mut scalars : & [ <V :: ScalarField as PrimeField >:: BigInt ] ,
610660 num_bits : usize ,
661+ serial : bool ,
611662) -> V {
612- if preamble ( & mut bases, & mut scalars) . is_none ( ) {
663+ if preamble ( & mut bases, & mut scalars, serial ) . is_none ( ) {
613664 return V :: zero ( ) ;
614665 }
615666 let size = scalars. len ( ) ;
@@ -701,9 +752,9 @@ fn msm_bigint<V: VariableBaseMSM>(
701752 } )
702753}
703754
704- fn msm_serial < V : VariableBaseMSM > (
755+ fn msm_serial < V : VariableBaseMSM , S : Into < u64 > + Copy + Send + Sync > (
705756 bases : & [ V :: MulBase ] ,
706- scalars : & [ impl Into < u64 > + Copy + Send + Sync ] ,
757+ scalars : & [ S ] ,
707758) -> V {
708759 let c = if bases. len ( ) < 32 {
709760 3
@@ -717,7 +768,7 @@ fn msm_serial<V: VariableBaseMSM>(
717768 // We divide up the bits 0..num_bits into windows of size `c`, and
718769 // in parallel process each such window.
719770 let two_to_c = 1 << c;
720- let window_sums: Vec < _ > = ( 0 ..( core:: mem:: size_of :: < u64 > ( ) * 8 ) )
771+ let window_sums: Vec < _ > = ( 0 ..( core:: mem:: size_of :: < S > ( ) * 8 ) )
721772 . step_by ( c)
722773 . map ( |w_start| {
723774 let mut res = zero;
0 commit comments