@@ -8,6 +8,7 @@ use ark_std::{
88 vec:: Vec ,
99} ;
1010
11+ use itertools:: { Either , Itertools } ;
1112#[ cfg( feature = "parallel" ) ]
1213use rayon:: prelude:: * ;
1314
@@ -528,6 +529,121 @@ pub fn msm_u64<V: VariableBaseMSM>(
528529 . sum ( )
529530}
530531
532+ pub fn msm_i64 < V : VariableBaseMSM > (
533+ mut bases : & [ V :: MulBase ] ,
534+ mut scalars : & [ i64 ] ,
535+ serial : bool ,
536+ ) -> V {
537+ let ( negative_bases, non_negative_bases) : ( Vec < V :: MulBase > , Vec < V :: MulBase > ) =
538+ bases. iter ( ) . enumerate ( ) . partition_map ( |( i, b) | {
539+ if scalars[ i] . is_negative ( ) {
540+ Either :: Left ( b)
541+ } else {
542+ Either :: Right ( b)
543+ }
544+ } ) ;
545+ let ( negative_scalars, non_negative_scalars) : ( Vec < u64 > , Vec < u64 > ) =
546+ scalars. iter ( ) . partition_map ( |s| {
547+ if s. is_negative ( ) {
548+ Either :: Left ( s. unsigned_abs ( ) )
549+ } else {
550+ Either :: Right ( s. unsigned_abs ( ) )
551+ }
552+ } ) ;
553+ if serial {
554+ return msm_serial :: < V , _ > ( & non_negative_bases, & non_negative_scalars)
555+ - msm_serial :: < V , _ > ( & negative_bases, & negative_scalars) ;
556+ } else {
557+ let chunk_size = match preamble ( & mut bases, & mut scalars, serial) {
558+ Some ( chunk_size) => chunk_size,
559+ None => return V :: zero ( ) ,
560+ } ;
561+
562+ let non_negative_msm: V = cfg_chunks ! ( non_negative_bases, chunk_size)
563+ . zip ( cfg_chunks ! ( non_negative_scalars, chunk_size) )
564+ . map ( |( non_negative_bases, non_negative_scalars) | {
565+ msm_serial :: < V , _ > ( non_negative_bases, non_negative_scalars)
566+ } )
567+ . sum ( ) ;
568+ let negative_msm: V = cfg_chunks ! ( negative_bases, chunk_size)
569+ . zip ( cfg_chunks ! ( negative_scalars, chunk_size) )
570+ . map ( |( negative_bases, negative_scalars) | {
571+ msm_serial :: < V , _ > ( negative_bases, negative_scalars)
572+ } )
573+ . sum ( ) ;
574+ non_negative_msm - negative_msm
575+ }
576+ }
577+
578+ pub fn msm_i128 < V : VariableBaseMSM > (
579+ mut bases : & [ V :: MulBase ] ,
580+ mut scalars : & [ i128 ] ,
581+ serial : bool ,
582+ ) -> V {
583+ let ( negative_bases, non_negative_bases) : ( Vec < V :: MulBase > , Vec < V :: MulBase > ) =
584+ bases. iter ( ) . enumerate ( ) . partition_map ( |( i, b) | {
585+ if scalars[ i] . is_negative ( ) {
586+ Either :: Left ( b)
587+ } else {
588+ Either :: Right ( b)
589+ }
590+ } ) ;
591+ let ( negative_scalars, non_negative_scalars) : ( Vec < u64 > , Vec < u64 > ) =
592+ scalars. iter ( ) . partition_map ( |s| {
593+ let absolute_val = s. unsigned_abs ( ) ;
594+ debug_assert ! (
595+ absolute_val <= u64 :: MAX as u128 ,
596+ "msm_i128 only supports scalars in the range [-u64::MAX, u64::MAX]"
597+ ) ;
598+ if s. is_negative ( ) {
599+ Either :: Left ( absolute_val as u64 )
600+ } else {
601+ Either :: Right ( absolute_val as u64 )
602+ }
603+ } ) ;
604+ if serial {
605+ return msm_serial :: < V , _ > ( & non_negative_bases, & non_negative_scalars)
606+ - msm_serial :: < V , _ > ( & negative_bases, & negative_scalars) ;
607+ } else {
608+ let chunk_size = match preamble ( & mut bases, & mut scalars, serial) {
609+ Some ( chunk_size) => chunk_size,
610+ None => return V :: zero ( ) ,
611+ } ;
612+
613+ let non_negative_msm: V = cfg_chunks ! ( non_negative_bases, chunk_size)
614+ . zip ( cfg_chunks ! ( non_negative_scalars, chunk_size) )
615+ . map ( |( non_negative_bases, non_negative_scalars) | {
616+ msm_serial :: < V , _ > ( non_negative_bases, non_negative_scalars)
617+ } )
618+ . sum ( ) ;
619+ let negative_msm: V = cfg_chunks ! ( negative_bases, chunk_size)
620+ . zip ( cfg_chunks ! ( negative_scalars, chunk_size) )
621+ . map ( |( negative_bases, negative_scalars) | {
622+ msm_serial :: < V , _ > ( negative_bases, negative_scalars)
623+ } )
624+ . sum ( ) ;
625+ non_negative_msm - negative_msm
626+ }
627+ }
628+
629+ pub fn msm_u128 < V : VariableBaseMSM > (
630+ mut bases : & [ V :: MulBase ] ,
631+ mut scalars : & [ u128 ] ,
632+ serial : bool ,
633+ ) -> V {
634+ if serial {
635+ return msm_serial :: < V , _ > ( bases, scalars) ;
636+ }
637+ let chunk_size = match preamble ( & mut bases, & mut scalars, serial) {
638+ Some ( chunk_size) => chunk_size,
639+ None => return V :: zero ( ) ,
640+ } ;
641+ cfg_chunks ! ( bases, chunk_size)
642+ . zip ( cfg_chunks ! ( scalars, chunk_size) )
643+ . map ( |( bases, scalars) | msm_serial :: < V , _ > ( bases, scalars) )
644+ . sum ( )
645+ }
646+
531647// Compute msm using windowed non-adjacent form
532648fn msm_bigint_wnaf_parallel < V : VariableBaseMSM > (
533649 bases : & [ V :: MulBase ] ,
@@ -752,9 +868,9 @@ fn msm_bigint<V: VariableBaseMSM>(
752868 } )
753869}
754870
755- fn msm_serial < V : VariableBaseMSM , S : Into < u64 > + Copy + Send + Sync > (
756- bases : & [ V :: MulBase ] ,
757- scalars : & [ S ] ,
871+ fn msm_serial < ' a , V : VariableBaseMSM , S : Into < u128 > + Copy + Send + Sync + ' a > (
872+ bases : impl Iterable < Item = & ' a V :: MulBase > ,
873+ scalars : impl Iterable < Item = & ' a S > ,
758874) -> V {
759875 let c = if bases. len ( ) < 32 {
760876 3
@@ -778,7 +894,7 @@ fn msm_serial<V: VariableBaseMSM, S: Into<u64> + Copy + Send + Sync>(
778894 // pointer and an index into the original vectors.
779895 scalars
780896 . iter ( )
781- . zip ( bases)
897+ . zip ( bases. iter ( ) )
782898 . filter_map ( |( & s, b) | {
783899 let s = s. into ( ) ;
784900 ( s != 0 ) . then_some ( ( s, b) )
@@ -797,7 +913,7 @@ fn msm_serial<V: VariableBaseMSM, S: Into<u64> + Copy + Send + Sync>(
797913 scalar >>= w_start as u32 ;
798914
799915 // We mod the remaining bits by 2^{window size}, thus taking `c` bits.
800- scalar %= two_to_c as u64 ;
916+ scalar %= two_to_c as u128 ;
801917
802918 // If the scalar is non-zero, we update the corresponding
803919 // bucket.
0 commit comments