Skip to content

Commit 87fe608

Browse files
committed
Make Shr for negative BigInt round down, like primitives do
Primitive integers always round down when shifting right, but `BigInt` was effectively rounding toward zero, because it just kept its sign and used the `BigUint` magnitude rounded down (always toward zero). Now we adjust the result of shifting negative values, and explicitly test that it matches the result for primitive integers.
1 parent 5e389ca commit 87fe608

File tree

3 files changed

+58
-18
lines changed

3 files changed

+58
-18
lines changed

src/bigint.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,12 +228,23 @@ impl<'a> Shl<usize> for &'a BigInt {
228228
}
229229
}
230230

231+
// Negative values need a rounding adjustment if there are any ones in the
232+
// bits that are getting shifted out.
233+
fn shr_round_down(i: &BigInt, rhs: usize) -> bool {
234+
i.is_negative() &&
235+
biguint::trailing_zeros(&i.data)
236+
.map(|n| n < rhs)
237+
.unwrap_or(false)
238+
}
239+
231240
impl Shr<usize> for BigInt {
232241
type Output = BigInt;
233242

234243
#[inline]
235244
fn shr(self, rhs: usize) -> BigInt {
236-
BigInt::from_biguint(self.sign, self.data >> rhs)
245+
let round_down = shr_round_down(&self, rhs);
246+
let data = self.data >> rhs;
247+
BigInt::from_biguint(self.sign, if round_down { data + 1u8 } else { data })
237248
}
238249
}
239250

@@ -242,7 +253,9 @@ impl<'a> Shr<usize> for &'a BigInt {
242253

243254
#[inline]
244255
fn shr(self, rhs: usize) -> BigInt {
245-
BigInt::from_biguint(self.sign, &self.data >> rhs)
256+
let round_down = shr_round_down(&self, rhs);
257+
let data = &self.data >> rhs;
258+
BigInt::from_biguint(self.sign, if round_down { data + 1u8 } else { data })
246259
}
247260
}
248261

src/biguint.rs

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,11 @@ impl Integer for BigUint {
944944
/// The result is always positive.
945945
#[inline]
946946
fn gcd(&self, other: &Self) -> Self {
947+
#[inline]
948+
fn twos(x: &BigUint) -> usize {
949+
trailing_zeros(x).unwrap_or(0)
950+
}
951+
947952
// Stein's algorithm
948953
if self.is_zero() {
949954
return other.clone();
@@ -955,17 +960,14 @@ impl Integer for BigUint {
955960
let mut n = other.clone();
956961

957962
// find common factors of 2
958-
let shift = cmp::min(
959-
n.trailing_zeros(),
960-
m.trailing_zeros()
961-
);
963+
let shift = cmp::min(twos(&n), twos(&m));
962964

963965
// divide m and n by 2 until odd
964966
// m inside loop
965-
n >>= n.trailing_zeros();
967+
n >>= twos(&n);
966968

967969
while !m.is_zero() {
968-
m >>= m.trailing_zeros();
970+
m >>= twos(&m);
969971
if n > m { mem::swap(&mut n, &mut m) }
970972
m -= &n;
971973
}
@@ -1628,16 +1630,6 @@ impl BigUint {
16281630
return self.data.len() * big_digit::BITS - zeros as usize;
16291631
}
16301632

1631-
// self is assumed to be normalized
1632-
fn trailing_zeros(&self) -> usize {
1633-
self.data
1634-
.iter()
1635-
.enumerate()
1636-
.find(|&(_, &digit)| digit != 0)
1637-
.map(|(i, digit)| i * big_digit::BITS + digit.trailing_zeros() as usize)
1638-
.unwrap_or(0)
1639-
}
1640-
16411633
/// Strips off trailing zero bigdigits - comparisons require the last element in the vector to
16421634
/// be nonzero.
16431635
#[inline]
@@ -1689,6 +1681,16 @@ impl BigUint {
16891681
}
16901682
}
16911683

1684+
/// Returns the number of least-significant bits that are zero,
1685+
/// or `None` if the entire number is zero.
1686+
pub fn trailing_zeros(u: &BigUint) -> Option<usize> {
1687+
u.data
1688+
.iter()
1689+
.enumerate()
1690+
.find(|&(_, &digit)| digit != 0)
1691+
.map(|(i, digit)| i * big_digit::BITS + digit.trailing_zeros() as usize)
1692+
}
1693+
16921694
#[cfg(feature = "serde")]
16931695
impl serde::Serialize for BigUint {
16941696
fn serialize<S>(&self, serializer: &mut S) -> Result<(), S::Error>

src/tests/bigint.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,3 +1192,28 @@ fn test_negative_rand_range() {
11921192
// Switching u and l should fail:
11931193
let _n: BigInt = rng.gen_bigint_range(&u, &l);
11941194
}
1195+
1196+
#[test]
1197+
fn test_negative_shr() {
1198+
assert_eq!(BigInt::from(-1) >> 1, BigInt::from(-1));
1199+
assert_eq!(BigInt::from(-2) >> 1, BigInt::from(-1));
1200+
assert_eq!(BigInt::from(-3) >> 1, BigInt::from(-2));
1201+
assert_eq!(BigInt::from(-3) >> 2, BigInt::from(-1));
1202+
}
1203+
1204+
#[test]
1205+
fn test_random_shr() {
1206+
use rand::Rng;
1207+
let mut rng = thread_rng();
1208+
1209+
for p in rng.gen_iter::<i64>().take(1000) {
1210+
let big = BigInt::from(p);
1211+
let bigger = &big << 1000;
1212+
assert_eq!(&bigger >> 1000, big);
1213+
for i in 0..64 {
1214+
let answer = BigInt::from(p >> i);
1215+
assert_eq!(&big >> i, answer);
1216+
assert_eq!(&bigger >> (1000 + i), answer);
1217+
}
1218+
}
1219+
}

0 commit comments

Comments
 (0)