Skip to content

Commit cd639fc

Browse files
committed
finish msm
1 parent 5f4b31d commit cd639fc

File tree

3 files changed

+140
-7
lines changed

3 files changed

+140
-7
lines changed

ec/src/scalar_mul/variable_base/mod.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pub mod stream_pippenger;
1616
pub use stream_pippenger::*;
1717

1818
use super::ScalarMul;
19+
use ark_ff::biginteger::{U128OrI128, U64OrI64};
1920

2021
#[cfg(all(
2122
target_has_atomic = "8",
@@ -644,6 +645,106 @@ pub fn msm_u128<V: VariableBaseMSM>(
644645
.sum()
645646
}
646647

648+
/// MSM over mixed-signed 64-bit integers using the small-scalar engine.
649+
pub fn msm_u64_or_i64<V: VariableBaseMSM>(
650+
mut bases: &[V::MulBase],
651+
mut scalars: &[U64OrI64],
652+
serial: bool,
653+
) -> V {
654+
// Partition by sign for better locality; build magnitudes as u64.
655+
let (negative_bases, non_negative_bases): (Vec<V::MulBase>, Vec<V::MulBase>) =
656+
bases
657+
.iter()
658+
.enumerate()
659+
.partition_map(|(i, b)| if scalars[i].is_negative() {
660+
Either::Left(b)
661+
} else {
662+
Either::Right(b)
663+
});
664+
let (negative_scalars, non_negative_scalars): (Vec<u64>, Vec<u64>) = scalars
665+
.iter()
666+
.partition_map(|s| match *s {
667+
U64OrI64::Unsigned(u) => Either::Right(u),
668+
U64OrI64::Signed(v) => {
669+
if v < 0 {
670+
Either::Left(v.unsigned_abs())
671+
} else {
672+
Either::Right(v as u64)
673+
}
674+
}
675+
});
676+
677+
if serial {
678+
return msm_serial::<V, _>(&non_negative_bases, &non_negative_scalars)
679+
- msm_serial::<V, _>(&negative_bases, &negative_scalars);
680+
} else {
681+
let chunk_size = match preamble(&mut bases, &mut scalars, serial) {
682+
Some(chunk_size) => chunk_size,
683+
None => return V::zero(),
684+
};
685+
686+
let non_negative_msm: V = cfg_chunks!(non_negative_bases, chunk_size)
687+
.zip(cfg_chunks!(non_negative_scalars, chunk_size))
688+
.map(|(b, s)| msm_serial::<V, _>(b, s))
689+
.sum();
690+
let negative_msm: V = cfg_chunks!(negative_bases, chunk_size)
691+
.zip(cfg_chunks!(negative_scalars, chunk_size))
692+
.map(|(b, s)| msm_serial::<V, _>(b, s))
693+
.sum();
694+
non_negative_msm - negative_msm
695+
}
696+
}
697+
698+
/// MSM over mixed-signed 128-bit integers.
699+
pub fn msm_u128_or_i128<V: VariableBaseMSM>(
700+
mut bases: &[V::MulBase],
701+
mut scalars: &[U128OrI128],
702+
serial: bool,
703+
) -> V {
704+
// u128 path with sign partitioning.
705+
let (negative_bases, non_negative_bases): (Vec<V::MulBase>, Vec<V::MulBase>) = bases
706+
.iter()
707+
.enumerate()
708+
.partition_map(|(i, b)| if match scalars[i] { U128OrI128::Signed(v) if v < 0 => true, _ => false } {
709+
Either::Left(b)
710+
} else {
711+
Either::Right(b)
712+
});
713+
let (negative_scalars, non_negative_scalars): (Vec<u128>, Vec<u128>) = scalars
714+
.iter()
715+
.partition_map(|s| match *s {
716+
U128OrI128::Unsigned(u) => Either::Right(u),
717+
U128OrI128::Signed(v) => {
718+
let abs = v.unsigned_abs();
719+
if v < 0 {
720+
Either::Left(abs)
721+
} else {
722+
Either::Right(abs)
723+
}
724+
}
725+
});
726+
727+
if serial {
728+
msm_serial::<V, _>(&non_negative_bases, &non_negative_scalars)
729+
- msm_serial::<V, _>(&negative_bases, &negative_scalars)
730+
} else {
731+
let chunk_size = match preamble(&mut bases, &mut scalars, serial) {
732+
Some(chunk_size) => chunk_size,
733+
None => return V::zero(),
734+
};
735+
736+
let non_negative_msm: V = cfg_chunks!(non_negative_bases, chunk_size)
737+
.zip(cfg_chunks!(non_negative_scalars, chunk_size))
738+
.map(|(b, s)| msm_serial::<V, _>(b, s))
739+
.sum();
740+
let negative_msm: V = cfg_chunks!(negative_bases, chunk_size)
741+
.zip(cfg_chunks!(negative_scalars, chunk_size))
742+
.map(|(b, s)| msm_serial::<V, _>(b, s))
743+
.sum();
744+
non_negative_msm - negative_msm
745+
}
746+
}
747+
647748
// Compute msm using windowed non-adjacent form
648749
fn msm_bigint_wnaf_parallel<V: VariableBaseMSM>(
649750
bases: &[V::MulBase],

ff/src/biginteger/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ use zeroize::Zeroize;
2929
#[macro_use]
3030
pub mod arithmetic;
3131

32+
pub mod types;
33+
pub use types::{U128OrI128, U64OrI64};
34+
3235
#[derive(Copy, Clone, PartialEq, Eq, Hash, Zeroize)]
3336
pub struct BigInt<const N: usize>(pub [u64; N]);
3437

ff/src/biginteger/types.rs

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,17 @@
66
///
77
/// Helper methods provide width-aware projections to `u64`/`i64` and a
88
/// canonical unsigned representation for lookup key construction.
9-
use allocative::Allocative;
109
use ark_serialize::{
1110
CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate,
1211
};
1312

14-
#[derive(Copy, Clone, Debug, PartialEq, Eq, Allocative)]
13+
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
1514
pub enum U64OrI64 {
1615
Unsigned(u64),
1716
Signed(i64),
1817
}
1918

20-
#[derive(Copy, Clone, Debug, PartialEq, Eq, Allocative)]
19+
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
2120
pub enum U128OrI128 {
2221
Unsigned(u128),
2322
Signed(i128),
@@ -101,6 +100,21 @@ impl U64OrI64 {
101100
_ => panic!("{XLEN}-bit word size is unsupported"),
102101
}
103102
}
103+
104+
/// Returns true if the value is negative.
105+
#[inline]
106+
pub fn is_negative(&self) -> bool {
107+
match *self {
108+
U64OrI64::Unsigned(_) => false,
109+
U64OrI64::Signed(s) => s < 0,
110+
}
111+
}
112+
113+
/// Returns true if the value is nonnegative (>= 0).
114+
#[inline]
115+
pub fn is_positive(&self) -> bool {
116+
!self.is_negative()
117+
}
104118
}
105119

106120
impl U128OrI128 {
@@ -119,6 +133,21 @@ impl U128OrI128 {
119133
U128OrI128::Signed(s) => s,
120134
}
121135
}
136+
137+
/// Returns true if the value is negative.
138+
#[inline]
139+
pub fn is_negative(&self) -> bool {
140+
match *self {
141+
U128OrI128::Unsigned(_) => false,
142+
U128OrI128::Signed(s) => s < 0,
143+
}
144+
}
145+
146+
/// Returns true if the value is nonnegative (>= 0).
147+
#[inline]
148+
pub fn is_positive(&self) -> bool {
149+
!self.is_negative()
150+
}
122151
}
123152

124153
impl core::cmp::PartialOrd for U64OrI64 {
@@ -208,7 +237,7 @@ impl Valid for U128OrI128 {
208237
}
209238

210239
impl CanonicalSerialize for U64OrI64 {
211-
fn serialize_with_mode<W: std::io::Write>(
240+
fn serialize_with_mode<W: ark_std::io::Write>(
212241
&self,
213242
mut writer: W,
214243
compress: Compress,
@@ -235,7 +264,7 @@ impl CanonicalSerialize for U64OrI64 {
235264
}
236265

237266
impl CanonicalDeserialize for U64OrI64 {
238-
fn deserialize_with_mode<R: std::io::Read>(
267+
fn deserialize_with_mode<R: ark_std::io::Read>(
239268
mut reader: R,
240269
compress: Compress,
241270
_validate: Validate,
@@ -256,7 +285,7 @@ impl CanonicalDeserialize for U64OrI64 {
256285
}
257286

258287
impl CanonicalSerialize for U128OrI128 {
259-
fn serialize_with_mode<W: std::io::Write>(
288+
fn serialize_with_mode<W: ark_std::io::Write>(
260289
&self,
261290
mut writer: W,
262291
compress: Compress,
@@ -283,7 +312,7 @@ impl CanonicalSerialize for U128OrI128 {
283312
}
284313

285314
impl CanonicalDeserialize for U128OrI128 {
286-
fn deserialize_with_mode<R: std::io::Read>(
315+
fn deserialize_with_mode<R: ark_std::io::Read>(
287316
mut reader: R,
288317
compress: Compress,
289318
_validate: Validate,

0 commit comments

Comments
 (0)