Skip to content

Commit 5ebdb17

Browse files
authored
Merge pull request #14 from a16z/feat/msm-i64-i128
Add msm_i64, msm_i128, msm_u128
2 parents 44143c8 + 84825f4 commit 5ebdb17

File tree

1 file changed

+121
-5
lines changed
  • ec/src/scalar_mul/variable_base

1 file changed

+121
-5
lines changed

ec/src/scalar_mul/variable_base/mod.rs

Lines changed: 121 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use ark_std::{
88
vec::Vec,
99
};
1010

11+
use itertools::{Either, Itertools};
1112
#[cfg(feature = "parallel")]
1213
use rayon::prelude::*;
1314

@@ -528,6 +529,121 @@ pub fn msm_u64<V: VariableBaseMSM>(
528529
.sum()
529530
}
530531

532+
pub fn msm_i64<V: VariableBaseMSM>(
533+
mut bases: &[V::MulBase],
534+
mut scalars: &[i64],
535+
serial: bool,
536+
) -> V {
537+
let (negative_bases, non_negative_bases): (Vec<V::MulBase>, Vec<V::MulBase>) =
538+
bases.iter().enumerate().partition_map(|(i, b)| {
539+
if scalars[i].is_negative() {
540+
Either::Left(b)
541+
} else {
542+
Either::Right(b)
543+
}
544+
});
545+
let (negative_scalars, non_negative_scalars): (Vec<u64>, Vec<u64>) =
546+
scalars.iter().partition_map(|s| {
547+
if s.is_negative() {
548+
Either::Left(s.unsigned_abs())
549+
} else {
550+
Either::Right(s.unsigned_abs())
551+
}
552+
});
553+
if serial {
554+
return msm_serial::<V, _>(&non_negative_bases, &non_negative_scalars)
555+
- msm_serial::<V, _>(&negative_bases, &negative_scalars);
556+
} else {
557+
let chunk_size = match preamble(&mut bases, &mut scalars, serial) {
558+
Some(chunk_size) => chunk_size,
559+
None => return V::zero(),
560+
};
561+
562+
let non_negative_msm: V = cfg_chunks!(non_negative_bases, chunk_size)
563+
.zip(cfg_chunks!(non_negative_scalars, chunk_size))
564+
.map(|(non_negative_bases, non_negative_scalars)| {
565+
msm_serial::<V, _>(non_negative_bases, non_negative_scalars)
566+
})
567+
.sum();
568+
let negative_msm: V = cfg_chunks!(negative_bases, chunk_size)
569+
.zip(cfg_chunks!(negative_scalars, chunk_size))
570+
.map(|(negative_bases, negative_scalars)| {
571+
msm_serial::<V, _>(negative_bases, negative_scalars)
572+
})
573+
.sum();
574+
non_negative_msm - negative_msm
575+
}
576+
}
577+
578+
pub fn msm_i128<V: VariableBaseMSM>(
579+
mut bases: &[V::MulBase],
580+
mut scalars: &[i128],
581+
serial: bool,
582+
) -> V {
583+
let (negative_bases, non_negative_bases): (Vec<V::MulBase>, Vec<V::MulBase>) =
584+
bases.iter().enumerate().partition_map(|(i, b)| {
585+
if scalars[i].is_negative() {
586+
Either::Left(b)
587+
} else {
588+
Either::Right(b)
589+
}
590+
});
591+
let (negative_scalars, non_negative_scalars): (Vec<u64>, Vec<u64>) =
592+
scalars.iter().partition_map(|s| {
593+
let absolute_val = s.unsigned_abs();
594+
debug_assert!(
595+
absolute_val <= u64::MAX as u128,
596+
"msm_i128 only supports scalars in the range [-u64::MAX, u64::MAX]"
597+
);
598+
if s.is_negative() {
599+
Either::Left(absolute_val as u64)
600+
} else {
601+
Either::Right(absolute_val as u64)
602+
}
603+
});
604+
if serial {
605+
return msm_serial::<V, _>(&non_negative_bases, &non_negative_scalars)
606+
- msm_serial::<V, _>(&negative_bases, &negative_scalars);
607+
} else {
608+
let chunk_size = match preamble(&mut bases, &mut scalars, serial) {
609+
Some(chunk_size) => chunk_size,
610+
None => return V::zero(),
611+
};
612+
613+
let non_negative_msm: V = cfg_chunks!(non_negative_bases, chunk_size)
614+
.zip(cfg_chunks!(non_negative_scalars, chunk_size))
615+
.map(|(non_negative_bases, non_negative_scalars)| {
616+
msm_serial::<V, _>(non_negative_bases, non_negative_scalars)
617+
})
618+
.sum();
619+
let negative_msm: V = cfg_chunks!(negative_bases, chunk_size)
620+
.zip(cfg_chunks!(negative_scalars, chunk_size))
621+
.map(|(negative_bases, negative_scalars)| {
622+
msm_serial::<V, _>(negative_bases, negative_scalars)
623+
})
624+
.sum();
625+
non_negative_msm - negative_msm
626+
}
627+
}
628+
629+
pub fn msm_u128<V: VariableBaseMSM>(
630+
mut bases: &[V::MulBase],
631+
mut scalars: &[u128],
632+
serial: bool,
633+
) -> V {
634+
if serial {
635+
return msm_serial::<V, _>(bases, scalars);
636+
}
637+
let chunk_size = match preamble(&mut bases, &mut scalars, serial) {
638+
Some(chunk_size) => chunk_size,
639+
None => return V::zero(),
640+
};
641+
cfg_chunks!(bases, chunk_size)
642+
.zip(cfg_chunks!(scalars, chunk_size))
643+
.map(|(bases, scalars)| msm_serial::<V, _>(bases, scalars))
644+
.sum()
645+
}
646+
531647
// Compute msm using windowed non-adjacent form
532648
fn msm_bigint_wnaf_parallel<V: VariableBaseMSM>(
533649
bases: &[V::MulBase],
@@ -752,9 +868,9 @@ fn msm_bigint<V: VariableBaseMSM>(
752868
})
753869
}
754870

755-
fn msm_serial<V: VariableBaseMSM, S: Into<u64> + Copy + Send + Sync>(
756-
bases: &[V::MulBase],
757-
scalars: &[S],
871+
fn msm_serial<'a, V: VariableBaseMSM, S: Into<u128> + Copy + Send + Sync + 'a>(
872+
bases: impl Iterable<Item = &'a V::MulBase>,
873+
scalars: impl Iterable<Item = &'a S>,
758874
) -> V {
759875
let c = if bases.len() < 32 {
760876
3
@@ -778,7 +894,7 @@ fn msm_serial<V: VariableBaseMSM, S: Into<u64> + Copy + Send + Sync>(
778894
// pointer and an index into the original vectors.
779895
scalars
780896
.iter()
781-
.zip(bases)
897+
.zip(bases.iter())
782898
.filter_map(|(&s, b)| {
783899
let s = s.into();
784900
(s != 0).then_some((s, b))
@@ -797,7 +913,7 @@ fn msm_serial<V: VariableBaseMSM, S: Into<u64> + Copy + Send + Sync>(
797913
scalar >>= w_start as u32;
798914

799915
// We mod the remaining bits by 2^{window size}, thus taking `c` bits.
800-
scalar %= two_to_c as u64;
916+
scalar %= two_to_c as u128;
801917

802918
// If the scalar is non-zero, we update the corresponding
803919
// bucket.

0 commit comments

Comments
 (0)