Skip to content

Commit 42c9c9a

Browse files
committed
Enable allocation reuse in scalar multiplication
1 parent d4015d9 commit 42c9c9a

File tree

3 files changed

+144
-86
lines changed

3 files changed

+144
-86
lines changed

src/bigint/multiplication.rs

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,49 @@ impl Mul<Sign> for Sign {
2121
}
2222
}
2323

24-
forward_all_binop_to_ref_ref!(impl Mul for BigInt, mul);
25-
26-
impl<'a, 'b> Mul<&'b BigInt> for &'a BigInt {
27-
type Output = BigInt;
28-
29-
#[inline]
30-
fn mul(self, other: &BigInt) -> BigInt {
31-
BigInt::from_biguint(self.sign * other.sign, &self.data * &other.data)
32-
}
24+
macro_rules! impl_mul {
25+
($(impl<$($a:lifetime),*> Mul<$Other:ty> for $Self:ty;)*) => {$(
26+
impl<$($a),*> Mul<$Other> for $Self {
27+
type Output = BigInt;
28+
29+
#[inline]
30+
fn mul(self, other: $Other) -> BigInt {
31+
// automatically match value/ref
32+
let BigInt { data: x, .. } = self;
33+
let BigInt { data: y, .. } = other;
34+
BigInt::from_biguint(self.sign * other.sign, x * y)
35+
}
36+
}
37+
)*}
38+
}
39+
impl_mul! {
40+
impl<> Mul<BigInt> for BigInt;
41+
impl<'b> Mul<&'b BigInt> for BigInt;
42+
impl<'a> Mul<BigInt> for &'a BigInt;
43+
impl<'a, 'b> Mul<&'b BigInt> for &'a BigInt;
44+
}
45+
46+
macro_rules! impl_mul_assign {
47+
($(impl<$($a:lifetime),*> MulAssign<$Other:ty> for BigInt;)*) => {$(
48+
impl<$($a),*> MulAssign<$Other> for BigInt {
49+
#[inline]
50+
fn mul_assign(&mut self, other: $Other) {
51+
// automatically match value/ref
52+
let BigInt { data: y, .. } = other;
53+
self.data *= y;
54+
if self.data.is_zero() {
55+
self.sign = NoSign;
56+
} else {
57+
self.sign = self.sign * other.sign;
58+
}
59+
}
60+
}
61+
)*}
3362
}
34-
35-
impl<'a> MulAssign<&'a BigInt> for BigInt {
36-
#[inline]
37-
fn mul_assign(&mut self, other: &BigInt) {
38-
*self = &*self * other;
39-
}
63+
impl_mul_assign! {
64+
impl<> MulAssign<BigInt> for BigInt;
65+
impl<'a> MulAssign<&'a BigInt> for BigInt;
4066
}
41-
forward_val_assign!(impl MulAssign for BigInt, mul_assign);
4267

4368
promote_all_scalars!(impl Mul for BigInt, mul);
4469
promote_all_scalars_assign!(impl MulAssign for BigInt, mul_assign);

src/biguint/multiplication.rs

Lines changed: 100 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::{BigInt, UsizePromotion};
1111
use core::cmp::Ordering;
1212
use core::iter::Product;
1313
use core::ops::{Mul, MulAssign};
14-
use num_traits::{CheckedMul, One, Zero};
14+
use num_traits::{CheckedMul, FromPrimitive, One, Zero};
1515

1616
#[inline]
1717
pub(super) fn mac_with_carry(
@@ -155,28 +155,28 @@ fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) {
155155

156156
// We reuse the same BigUint for all the intermediate multiplies and have to size p
157157
// appropriately here: x1.len() >= x0.len and y1.len() >= y0.len():
158-
let len = x1.len() + y1.len() + 1;
158+
let len = x1.len() + y1.len();
159159
let mut p = BigUint { data: vec![0; len] };
160160

161161
// p2 = x1 * y1
162-
mac3(&mut p.data[..], x1, y1);
162+
mac3(&mut p.data, x1, y1);
163163

164164
// Not required, but the adds go faster if we drop any unneeded 0s from the end:
165165
p.normalize();
166166

167-
add2(&mut acc[b..], &p.data[..]);
168-
add2(&mut acc[b * 2..], &p.data[..]);
167+
add2(&mut acc[b..], &p.data);
168+
add2(&mut acc[b * 2..], &p.data);
169169

170170
// Zero out p before the next multiply:
171171
p.data.truncate(0);
172172
p.data.resize(len, 0);
173173

174174
// p0 = x0 * y0
175-
mac3(&mut p.data[..], x0, y0);
175+
mac3(&mut p.data, x0, y0);
176176
p.normalize();
177177

178-
add2(&mut acc[..], &p.data[..]);
179-
add2(&mut acc[b..], &p.data[..]);
178+
add2(acc, &p.data);
179+
add2(&mut acc[b..], &p.data);
180180

181181
// p1 = (x1 - x0) * (y1 - y0)
182182
// We do this one last, since it may be negative and acc can't ever be negative:
@@ -188,13 +188,13 @@ fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) {
188188
p.data.truncate(0);
189189
p.data.resize(len, 0);
190190

191-
mac3(&mut p.data[..], &j0.data[..], &j1.data[..]);
191+
mac3(&mut p.data, &j0.data, &j1.data);
192192
p.normalize();
193193

194-
sub2(&mut acc[b..], &p.data[..]);
194+
sub2(&mut acc[b..], &p.data);
195195
}
196196
Minus => {
197-
mac3(&mut acc[b..], &j0.data[..], &j1.data[..]);
197+
mac3(&mut acc[b..], &j0.data, &j1.data);
198198
}
199199
NoSign => (),
200200
}
@@ -321,25 +321,41 @@ fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) {
321321
}
322322

323323
fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
324-
let len = x.len() + y.len() + 1;
324+
let len = x.len() + y.len();
325325
let mut prod = BigUint { data: vec![0; len] };
326326

327-
mac3(&mut prod.data[..], x, y);
327+
mac3(&mut prod.data, x, y);
328328
prod.normalized()
329329
}
330330

331-
fn scalar_mul(a: &mut [BigDigit], b: BigDigit) -> BigDigit {
332-
let mut carry = 0;
333-
for a in a.iter_mut() {
334-
*a = mul_with_carry(*a, b, &mut carry);
331+
fn scalar_mul(a: &mut BigUint, b: BigDigit) {
332+
match b {
333+
0 => a.set_zero(),
334+
1 => {}
335+
_ => {
336+
if b.is_power_of_two() {
337+
*a <<= b.trailing_zeros();
338+
} else {
339+
let mut carry = 0;
340+
for a in a.data.iter_mut() {
341+
*a = mul_with_carry(*a, b, &mut carry);
342+
}
343+
if carry != 0 {
344+
a.data.push(carry as BigDigit);
345+
}
346+
}
347+
}
335348
}
336-
carry as BigDigit
337349
}
338350

339351
fn sub_sign(mut a: &[BigDigit], mut b: &[BigDigit]) -> (Sign, BigUint) {
340352
// Normalize:
341-
a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
342-
b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
353+
if let Some(&0) = a.last() {
354+
a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
355+
}
356+
if let Some(&0) = b.last() {
357+
b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
358+
}
343359

344360
match cmp_slice(a, b) {
345361
Ordering::Greater => {
@@ -356,22 +372,55 @@ fn sub_sign(mut a: &[BigDigit], mut b: &[BigDigit]) -> (Sign, BigUint) {
356372
}
357373
}
358374

359-
forward_all_binop_to_ref_ref!(impl Mul for BigUint, mul);
360-
forward_val_assign!(impl MulAssign for BigUint, mul_assign);
361-
362-
impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint {
363-
type Output = BigUint;
375+
macro_rules! impl_mul {
376+
($(impl<$($a:lifetime),*> Mul<$Other:ty> for $Self:ty;)*) => {$(
377+
impl<$($a),*> Mul<$Other> for $Self {
378+
type Output = BigUint;
379+
380+
#[inline]
381+
fn mul(self, other: $Other) -> BigUint {
382+
match (&*self.data, &*other.data) {
383+
// multiply by zero
384+
(&[], _) | (_, &[]) => BigUint::zero(),
385+
// multiply by a scalar
386+
(_, &[digit]) => self * digit,
387+
(&[digit], _) => other * digit,
388+
// full multiplication
389+
(x, y) => mul3(x, y),
390+
}
391+
}
392+
}
393+
)*}
394+
}
395+
impl_mul! {
396+
impl<> Mul<BigUint> for BigUint;
397+
impl<'b> Mul<&'b BigUint> for BigUint;
398+
impl<'a> Mul<BigUint> for &'a BigUint;
399+
impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint;
400+
}
364401

365-
#[inline]
366-
fn mul(self, other: &BigUint) -> BigUint {
367-
mul3(&self.data[..], &other.data[..])
368-
}
402+
macro_rules! impl_mul_assign {
403+
($(impl<$($a:lifetime),*> MulAssign<$Other:ty> for BigUint;)*) => {$(
404+
impl<$($a),*> MulAssign<$Other> for BigUint {
405+
#[inline]
406+
fn mul_assign(&mut self, other: $Other) {
407+
match (&*self.data, &*other.data) {
408+
// multiply by zero
409+
(&[], _) => {},
410+
(_, &[]) => self.set_zero(),
411+
// multiply by a scalar
412+
(_, &[digit]) => *self *= digit,
413+
(&[digit], _) => *self = other * digit,
414+
// full multiplication
415+
(x, y) => *self = mul3(x, y),
416+
}
417+
}
418+
}
419+
)*}
369420
}
370-
impl<'a> MulAssign<&'a BigUint> for BigUint {
371-
#[inline]
372-
fn mul_assign(&mut self, other: &'a BigUint) {
373-
*self = &*self * other
374-
}
421+
impl_mul_assign! {
422+
impl<> MulAssign<BigUint> for BigUint;
423+
impl<'a> MulAssign<&'a BigUint> for BigUint;
375424
}
376425

377426
promote_unsigned_scalars!(impl Mul for BigUint, mul);
@@ -392,14 +441,7 @@ impl Mul<u32> for BigUint {
392441
impl MulAssign<u32> for BigUint {
393442
#[inline]
394443
fn mul_assign(&mut self, other: u32) {
395-
if other == 0 {
396-
self.data.clear();
397-
} else {
398-
let carry = scalar_mul(&mut self.data[..], other as BigDigit);
399-
if carry != 0 {
400-
self.data.push(carry);
401-
}
402-
}
444+
scalar_mul(self, other as BigDigit);
403445
}
404446
}
405447

@@ -416,27 +458,18 @@ impl MulAssign<u64> for BigUint {
416458
#[cfg(not(u64_digit))]
417459
#[inline]
418460
fn mul_assign(&mut self, other: u64) {
419-
if other == 0 {
420-
self.data.clear();
421-
} else if other <= u64::from(BigDigit::max_value()) {
422-
*self *= other as BigDigit
461+
if let Some(other) = BigDigit::from_u64(other) {
462+
scalar_mul(self, other);
423463
} else {
424464
let (hi, lo) = big_digit::from_doublebigdigit(other);
425-
*self = mul3(&self.data[..], &[lo, hi])
465+
*self = mul3(&self.data, &[lo, hi]);
426466
}
427467
}
428468

429469
#[cfg(u64_digit)]
430470
#[inline]
431471
fn mul_assign(&mut self, other: u64) {
432-
if other == 0 {
433-
self.data.clear();
434-
} else {
435-
let carry = scalar_mul(&mut self.data[..], other as BigDigit);
436-
if carry != 0 {
437-
self.data.push(carry);
438-
}
439-
}
472+
scalar_mul(self, other);
440473
}
441474
}
442475

@@ -454,26 +487,25 @@ impl MulAssign<u128> for BigUint {
454487
#[cfg(not(u64_digit))]
455488
#[inline]
456489
fn mul_assign(&mut self, other: u128) {
457-
if other == 0 {
458-
self.data.clear();
459-
} else if other <= u128::from(BigDigit::max_value()) {
460-
*self *= other as BigDigit
490+
if let Some(other) = BigDigit::from_u128(other) {
491+
scalar_mul(self, other);
461492
} else {
462-
let (a, b, c, d) = u32_from_u128(other);
463-
*self = mul3(&self.data[..], &[d, c, b, a])
493+
*self = match u32_from_u128(other) {
494+
(0, 0, c, d) => mul3(&self.data, &[d, c]),
495+
(0, b, c, d) => mul3(&self.data, &[d, c, b]),
496+
(a, b, c, d) => mul3(&self.data, &[d, c, b, a]),
497+
};
464498
}
465499
}
466500

467501
#[cfg(u64_digit)]
468502
#[inline]
469503
fn mul_assign(&mut self, other: u128) {
470-
if other == 0 {
471-
self.data.clear();
472-
} else if other <= BigDigit::max_value() as u128 {
473-
*self *= other as BigDigit
504+
if let Some(other) = BigDigit::from_u128(other) {
505+
scalar_mul(self, other);
474506
} else {
475507
let (hi, lo) = big_digit::from_doublebigdigit(other);
476-
*self = mul3(&self.data[..], &[lo, hi])
508+
*self = mul3(&self.data, &[lo, hi]);
477509
}
478510
}
479511
}
@@ -502,6 +534,6 @@ fn test_sub_sign() {
502534
let a_i = BigInt::from(a.clone());
503535
let b_i = BigInt::from(b.clone());
504536

505-
assert_eq!(sub_sign_i(&a.data[..], &b.data[..]), &a_i - &b_i);
506-
assert_eq!(sub_sign_i(&b.data[..], &a.data[..]), &b_i - &a_i);
537+
assert_eq!(sub_sign_i(&a.data, &b.data), &a_i - &b_i);
538+
assert_eq!(sub_sign_i(&b.data, &a.data), &b_i - &a_i);
507539
}

src/biguint/power.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ macro_rules! pow_impl {
8585
exp >>= 1;
8686
base = &base * &base;
8787
if exp & 1 == 1 {
88-
acc = &acc * &base;
88+
acc *= &base;
8989
}
9090
}
9191
acc
@@ -185,7 +185,8 @@ fn plain_modpow(base: &BigUint, exp_data: &[BigDigit], modulus: &BigUint) -> Big
185185
let mut unit = |exp_is_odd| {
186186
base = &base * &base % modulus;
187187
if exp_is_odd {
188-
acc = &acc * &base % modulus;
188+
acc *= &base;
189+
acc %= modulus;
189190
}
190191
};
191192

0 commit comments

Comments
 (0)