Skip to content

Commit 77edc1a

Browse files
committed
Expose serial parameter for MSM functions
1 parent efc56e0 commit 77edc1a

File tree

1 file changed

+95
-44
lines changed
  • ec/src/scalar_mul/variable_base

1 file changed

+95
-44
lines changed

ec/src/scalar_mul/variable_base/mod.rs

Lines changed: 95 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ pub trait VariableBaseMSM: ScalarMul + for<'a> AddAssign<&'a Self::Bucket> {
6262
.collect::<Vec<_>>();
6363
Self::msm_bigint(bases, bigints.as_slice())
6464
}
65+
fn msm_unchecked_serial(bases: &[Self::MulBase], scalars: &[Self::ScalarField]) -> Self {
66+
let bigints = cfg_into_iter!(scalars)
67+
.map(|s| s.into_bigint())
68+
.collect::<Vec<_>>();
69+
Self::msm_bigint_serial(bases, bigints.as_slice())
70+
}
6571

6672
/// Performs multi-scalar multiplication.
6773
///
@@ -75,47 +81,62 @@ pub trait VariableBaseMSM: ScalarMul + for<'a> AddAssign<&'a Self::Bucket> {
7581
.then(|| Self::msm_unchecked(bases, scalars))
7682
.ok_or_else(|| bases.len().min(scalars.len()))
7783
}
84+
fn msm_serial(bases: &[Self::MulBase], scalars: &[Self::ScalarField]) -> Result<Self, usize> {
85+
(bases.len() == scalars.len())
86+
.then(|| Self::msm_unchecked_serial(bases, scalars))
87+
.ok_or_else(|| bases.len().min(scalars.len()))
88+
}
7889

7990
/// Optimized implementation of multi-scalar multiplication.
8091
fn msm_bigint(
8192
bases: &[Self::MulBase],
8293
bigints: &[<Self::ScalarField as PrimeField>::BigInt],
8394
) -> Self {
8495
if Self::NEGATION_IS_CHEAP {
85-
msm_signed(bases, bigints)
96+
msm_signed(bases, bigints, false)
97+
} else {
98+
msm_unsigned(bases, bigints, false)
99+
}
100+
}
101+
fn msm_bigint_serial(
102+
bases: &[Self::MulBase],
103+
bigints: &[<Self::ScalarField as PrimeField>::BigInt],
104+
) -> Self {
105+
if Self::NEGATION_IS_CHEAP {
106+
msm_signed(bases, bigints, true)
86107
} else {
87-
msm_unsigned(bases, bigints)
108+
msm_unsigned(bases, bigints, true)
88109
}
89110
}
90111

91112
/// Performs multi-scalar multiplication when the scalars are known to be boolean.
92113
/// The default implementation is faster than [`Self::msm_bigint`].
93-
fn msm_u1(bases: &[Self::MulBase], scalars: &[bool]) -> Self {
94-
msm_binary(bases, scalars)
114+
fn msm_u1(bases: &[Self::MulBase], scalars: &[bool], serial: bool) -> Self {
115+
msm_binary(bases, scalars, serial)
95116
}
96117

97118
/// Performs multi-scalar multiplication when the scalars are known to be `u8`-sized.
98119
/// The default implementation is faster than [`Self::msm_bigint`].
99-
fn msm_u8(bases: &[Self::MulBase], scalars: &[u8]) -> Self {
100-
msm_u8(bases, scalars)
120+
fn msm_u8(bases: &[Self::MulBase], scalars: &[u8], serial: bool) -> Self {
121+
msm_u8(bases, scalars, serial)
101122
}
102123

103124
/// Performs multi-scalar multiplication when the scalars are known to be `u16`-sized.
104125
/// The default implementation is faster than [`Self::msm_bigint`].
105-
fn msm_u16(bases: &[Self::MulBase], scalars: &[u16]) -> Self {
106-
msm_u16(bases, scalars)
126+
fn msm_u16(bases: &[Self::MulBase], scalars: &[u16], serial: bool) -> Self {
127+
msm_u16(bases, scalars, serial)
107128
}
108129

109130
/// Performs multi-scalar multiplication when the scalars are known to be `u32`-sized.
110131
/// The default implementation is faster than [`Self::msm_bigint`].
111-
fn msm_u32(bases: &[Self::MulBase], scalars: &[u32]) -> Self {
112-
msm_u32(bases, scalars)
132+
fn msm_u32(bases: &[Self::MulBase], scalars: &[u32], serial: bool) -> Self {
133+
msm_u32(bases, scalars, serial)
113134
}
114135

115136
/// Performs multi-scalar multiplication when the scalars are known to be `u64`-sized.
116137
/// The default implementation is faster than [`Self::msm_bigint`].
117-
fn msm_u64(bases: &[Self::MulBase], scalars: &[u64]) -> Self {
118-
msm_u64(bases, scalars)
138+
fn msm_u64(bases: &[Self::MulBase], scalars: &[u64], serial: bool) -> Self {
139+
msm_u64(bases, scalars, serial)
119140
}
120141

121142
/// Streaming multi-scalar multiplication algorithm with hard-coded chunk
@@ -188,6 +209,7 @@ fn iget_group<A: Send + Sync, B: Send + Sync>(
188209
fn msm_unsigned<V: VariableBaseMSM>(
189210
bases: &[V::MulBase],
190211
scalars: &[<V::ScalarField as PrimeField>::BigInt],
212+
serial: bool,
191213
) -> V {
192214
// Partition scalars according to whether
193215
// 1. they are in the range {0, 1};
@@ -253,12 +275,12 @@ fn msm_unsigned<V: VariableBaseMSM>(
253275
});
254276
let (b7, s7) = uget_group(&grouped[s7..], |i| (bases[i], scalars[i]));
255277

256-
let result: V = msm_binary::<V>(&b1, &s1)
257-
+ msm_u8::<V>(&b3, &s3)
258-
+ msm_u16::<V>(&b4, &s4)
259-
+ msm_u32::<V>(&b5, &s5)
260-
+ msm_u64::<V>(&b6, &s6)
261-
+ msm_bigint::<V>(&b7, &s7, V::ScalarField::MODULUS_BIT_SIZE as usize);
278+
let result: V = msm_binary::<V>(&b1, &s1, serial)
279+
+ msm_u8::<V>(&b3, &s3, serial)
280+
+ msm_u16::<V>(&b4, &s4, serial)
281+
+ msm_u32::<V>(&b5, &s5, serial)
282+
+ msm_u64::<V>(&b6, &s6, serial)
283+
+ msm_bigint::<V>(&b7, &s7, V::ScalarField::MODULUS_BIT_SIZE as usize, serial);
262284
result.into()
263285
}
264286

@@ -276,6 +298,7 @@ fn sub<B: BigInteger>(m: &B, scalar: &B) -> u64 {
276298
fn msm_signed<V: VariableBaseMSM>(
277299
bases: &[V::MulBase],
278300
scalars: &[<V::ScalarField as PrimeField>::BigInt],
301+
serial: bool,
279302
) -> V {
280303
// Partition scalars according to whether
281304
// 1. they are in the range {-1, 0, 1};
@@ -358,7 +381,7 @@ fn msm_signed<V: VariableBaseMSM>(
358381
let (ib, is) = iget_group(&grouped[si1..su8], |i| {
359382
(bases[i], sub(&m, &scalars[i]) == 1)
360383
});
361-
result = msm_binary::<V>(&ub, &us) - msm_binary::<V>(&ib, &is);
384+
result = msm_binary::<V>(&ub, &us, serial) - msm_binary::<V>(&ib, &is, serial);
362385

363386
// Handle positive and negative u8 scalars.
364387
let (ub, us) = iget_group(&grouped[su8..si8], |i| {
@@ -367,7 +390,7 @@ fn msm_signed<V: VariableBaseMSM>(
367390
let (ib, is) = iget_group(&grouped[si8..su16], |i| {
368391
(bases[i], sub(&m, &scalars[i]) as u8)
369392
});
370-
result += msm_u8::<V>(&ub, &us) - msm_u8::<V>(&ib, &is);
393+
result += msm_u8::<V>(&ub, &us, serial) - msm_u8::<V>(&ib, &is, serial);
371394

372395
// Handle positive and negative u16 scalars.
373396
let (ub, us) = iget_group(&grouped[su16..si16], |i| {
@@ -376,7 +399,7 @@ fn msm_signed<V: VariableBaseMSM>(
376399
let (ib, is) = iget_group(&grouped[si16..su32], |i| {
377400
(bases[i], sub(&m, &scalars[i]) as u16)
378401
});
379-
result += msm_u16::<V>(&ub, &us) - msm_u16::<V>(&ib, &is);
402+
result += msm_u16::<V>(&ub, &us, serial) - msm_u16::<V>(&ib, &is, serial);
380403

381404
// Handle positive and negative u32 scalars.
382405
let (ub, us) = iget_group(&grouped[su32..si32], |i| {
@@ -385,12 +408,12 @@ fn msm_signed<V: VariableBaseMSM>(
385408
let (ib, is) = iget_group(&grouped[si32..su64], |i| {
386409
(bases[i], sub(&m, &scalars[i]) as u32)
387410
});
388-
result += msm_u32::<V>(&ub, &us) - msm_u32::<V>(&ib, &is);
411+
result += msm_u32::<V>(&ub, &us, serial) - msm_u32::<V>(&ib, &is, serial);
389412

390413
// Handle positive and negative u64 scalars.
391414
let (ub, us) = iget_group(&grouped[su64..si64], |i| (bases[i], scalars[i].as_ref()[0]));
392415
let (ib, is) = iget_group(&grouped[si64..sf], |i| (bases[i], sub(&m, &scalars[i])));
393-
result += msm_u64::<V>(&ub, &us) - msm_u64::<V>(&ib, &is);
416+
result += msm_u64::<V>(&ub, &us, serial) - msm_u64::<V>(&ib, &is, serial);
394417

395418
// Handle the rest of the scalars.
396419
let (bf, sf) = iget_group(&grouped[sf..], |i| (bases[i], scalars[i]));
@@ -399,15 +422,19 @@ fn msm_signed<V: VariableBaseMSM>(
399422
result.into()
400423
}
401424

402-
fn preamble<A, B>(bases: &mut &[A], scalars: &mut &[B]) -> Option<usize> {
425+
fn preamble<A, B>(bases: &mut &[A], scalars: &mut &[B], _serial: bool) -> Option<usize> {
403426
let size = bases.len().min(scalars.len());
404427
if size == 0 {
405428
return None;
406429
}
407430
#[cfg(feature = "parallel")]
408431
let chunk_size = {
409432
let chunk_size = size / rayon::current_num_threads();
410-
if chunk_size == 0 { size } else { chunk_size }
433+
if _serial || chunk_size == 0 {
434+
size
435+
} else {
436+
chunk_size
437+
}
411438
};
412439
#[cfg(not(feature = "parallel"))]
413440
let chunk_size = size;
@@ -419,8 +446,12 @@ fn preamble<A, B>(bases: &mut &[A], scalars: &mut &[B]) -> Option<usize> {
419446

420447
/// Computes multi-scalar multiplication where the scalars
421448
/// lie in the range {-1, 0, 1}.
422-
pub fn msm_binary<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[bool]) -> V {
423-
let chunk_size = match preamble(&mut bases, &mut scalars) {
449+
pub fn msm_binary<V: VariableBaseMSM>(
450+
mut bases: &[V::MulBase],
451+
mut scalars: &[bool],
452+
serial: bool,
453+
) -> V {
454+
let chunk_size = match preamble(&mut bases, &mut scalars, serial) {
424455
Some(chunk_size) => chunk_size,
425456
None => return V::zero(),
426457
};
@@ -438,47 +469,62 @@ pub fn msm_binary<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[b
438469
.sum()
439470
}
440471

441-
pub fn msm_u8<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[u8]) -> V {
442-
let chunk_size = match preamble(&mut bases, &mut scalars) {
472+
pub fn msm_u8<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[u8], serial: bool) -> V {
473+
let chunk_size = match preamble(&mut bases, &mut scalars, serial) {
443474
Some(chunk_size) => chunk_size,
444475
None => return V::zero(),
445476
};
446477
cfg_chunks!(bases, chunk_size)
447478
.zip(cfg_chunks!(scalars, chunk_size))
448-
.map(|(bases, scalars)| msm_serial::<V>(bases, scalars))
479+
.map(|(bases, scalars)| msm_serial::<V, _>(bases, scalars))
449480
.sum()
450481
}
451482

452-
pub fn msm_u16<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[u16]) -> V {
453-
let chunk_size = match preamble(&mut bases, &mut scalars) {
483+
pub fn msm_u16<V: VariableBaseMSM>(
484+
mut bases: &[V::MulBase],
485+
mut scalars: &[u16],
486+
serial: bool,
487+
) -> V {
488+
let chunk_size = match preamble(&mut bases, &mut scalars, serial) {
454489
Some(chunk_size) => chunk_size,
455490
None => return V::zero(),
456491
};
457492
cfg_chunks!(bases, chunk_size)
458493
.zip(cfg_chunks!(scalars, chunk_size))
459-
.map(|(bases, scalars)| msm_serial::<V>(bases, scalars))
494+
.map(|(bases, scalars)| msm_serial::<V, _>(bases, scalars))
460495
.sum()
461496
}
462497

463-
pub fn msm_u32<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[u32]) -> V {
464-
let chunk_size = match preamble(&mut bases, &mut scalars) {
498+
pub fn msm_u32<V: VariableBaseMSM>(
499+
mut bases: &[V::MulBase],
500+
mut scalars: &[u32],
501+
serial: bool,
502+
) -> V {
503+
let chunk_size = match preamble(&mut bases, &mut scalars, serial) {
465504
Some(chunk_size) => chunk_size,
466505
None => return V::zero(),
467506
};
468507
cfg_chunks!(bases, chunk_size)
469508
.zip(cfg_chunks!(scalars, chunk_size))
470-
.map(|(bases, scalars)| msm_serial::<V>(bases, scalars))
509+
.map(|(bases, scalars)| msm_serial::<V, _>(bases, scalars))
471510
.sum()
472511
}
473512

474-
pub fn msm_u64<V: VariableBaseMSM>(mut bases: &[V::MulBase], mut scalars: &[u64]) -> V {
475-
let chunk_size = match preamble(&mut bases, &mut scalars) {
513+
pub fn msm_u64<V: VariableBaseMSM>(
514+
mut bases: &[V::MulBase],
515+
mut scalars: &[u64],
516+
serial: bool,
517+
) -> V {
518+
if serial {
519+
return msm_serial::<V, _>(bases, scalars);
520+
}
521+
let chunk_size = match preamble(&mut bases, &mut scalars, serial) {
476522
Some(chunk_size) => chunk_size,
477523
None => return V::zero(),
478524
};
479525
cfg_chunks!(bases, chunk_size)
480526
.zip(cfg_chunks!(scalars, chunk_size))
481-
.map(|(bases, scalars)| msm_serial::<V>(bases, scalars))
527+
.map(|(bases, scalars)| msm_serial::<V, _>(bases, scalars))
482528
.sum()
483529
}
484530

@@ -577,7 +623,11 @@ fn msm_bigint_wnaf<V: VariableBaseMSM>(
577623
cur_num_threads / THREADS_PER_CHUNK
578624
};
579625
let chunk_size = size / num_chunks;
580-
if chunk_size == 0 { size } else { chunk_size }
626+
if chunk_size == 0 {
627+
size
628+
} else {
629+
chunk_size
630+
}
581631
};
582632
#[cfg(not(feature = "parallel"))]
583633
let chunk_size = size;
@@ -608,8 +658,9 @@ fn msm_bigint<V: VariableBaseMSM>(
608658
mut bases: &[V::MulBase],
609659
mut scalars: &[<V::ScalarField as PrimeField>::BigInt],
610660
num_bits: usize,
661+
serial: bool,
611662
) -> V {
612-
if preamble(&mut bases, &mut scalars).is_none() {
663+
if preamble(&mut bases, &mut scalars, serial).is_none() {
613664
return V::zero();
614665
}
615666
let size = scalars.len();
@@ -701,9 +752,9 @@ fn msm_bigint<V: VariableBaseMSM>(
701752
})
702753
}
703754

704-
fn msm_serial<V: VariableBaseMSM>(
755+
fn msm_serial<V: VariableBaseMSM, S: Into<u64> + Copy + Send + Sync>(
705756
bases: &[V::MulBase],
706-
scalars: &[impl Into<u64> + Copy + Send + Sync],
757+
scalars: &[S],
707758
) -> V {
708759
let c = if bases.len() < 32 {
709760
3
@@ -717,7 +768,7 @@ fn msm_serial<V: VariableBaseMSM>(
717768
// We divide up the bits 0..num_bits into windows of size `c`, and
718769
// in parallel process each such window.
719770
let two_to_c = 1 << c;
720-
let window_sums: Vec<_> = (0..(core::mem::size_of::<u64>() * 8))
771+
let window_sums: Vec<_> = (0..(core::mem::size_of::<S>() * 8))
721772
.step_by(c)
722773
.map(|w_start| {
723774
let mut res = zero;

0 commit comments

Comments
 (0)