Skip to content

Commit 30ec1f8

Browse files
committed
Better handling of set_bit for negative numbers
1 parent 49a14b5 commit 30ec1f8

File tree

2 files changed

+100
-41
lines changed

2 files changed

+100
-41
lines changed

src/bigint.rs

Lines changed: 78 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3268,10 +3268,10 @@ impl BigInt {
32683268
true
32693269
} else {
32703270
let trailing_zeros = self.data.trailing_zeros().unwrap();
3271-
match bit.cmp(&trailing_zeros) {
3272-
Ordering::Less => false,
3273-
Ordering::Equal => true,
3274-
Ordering::Greater => !self.data.bit(bit),
3271+
match Ord::cmp(&bit, &trailing_zeros) {
3272+
Less => false,
3273+
Equal => true,
3274+
Greater => !self.data.bit(bit),
32753275
}
32763276
}
32773277
} else {
@@ -3289,7 +3289,7 @@ impl BigInt {
32893289
self.data.set_bit(bit, true);
32903290
self.sign = Sign::Plus;
32913291
} else {
3292-
// clearing a bit for zero is a no-op
3292+
// Clearing a bit for zero is a no-op
32933293
}
32943294
}
32953295
Sign::Minus => {
@@ -3299,49 +3299,87 @@ impl BigInt {
32993299
self.data.set_bit(bit, true);
33003300
}
33013301
} else {
3302+
// If the Uint number is
3303+
// ... 0 x 1 0 ... 0
3304+
// then the two's complement is
3305+
// ... 1 !x 1 0 ... 0
3306+
// |-- bit at position 'trailing_zeros'
3307+
// where !x is obtained from x by flipping each bit
33023308
let trailing_zeros = self.data.trailing_zeros().unwrap();
3303-
if bit > trailing_zeros {
3304-
self.data.set_bit(bit, !value);
3305-
} else if bit < trailing_zeros && !value {
3306-
// bit is already cleared
3307-
} else if bit == trailing_zeros && value {
3308-
// bit is already set
3309-
} else {
3310-
// general case
3311-
let bit_index = (bit / bits_per_digit).to_usize().unwrap();
3312-
let bit_mask = (1 as BigDigit) << (bit % bits_per_digit);
3313-
let mut carry_in = 1;
3314-
let mut carry_out = 1;
3315-
let mut digit_iter = self.digits_mut().iter_mut().skip(bit_index);
3316-
3317-
let digit = digit_iter.next().unwrap();
3318-
let twos_in = negate_carry(*digit, &mut carry_in);
3319-
let twos_out = if value {
3320-
// set bit
3321-
twos_in | bit_mask
3322-
} else {
3323-
// clear bit
3324-
twos_in & !bit_mask
3325-
};
3326-
*digit = negate_carry(twos_out, &mut carry_out);
3327-
3328-
for digit in digit_iter {
3329-
if carry_in == 0 && carry_out == 0 {
3330-
// no more digits will change
3331-
break;
3309+
match Ord::cmp(&bit, &trailing_zeros) {
3310+
Less => {
3311+
if value {
3312+
// We need to flip each bit from position 'bit' to 'trailing_zeros', both inclusive
3313+
// ... 1 !x 1 0 ... 0 ... 0
3314+
// |-- bit at position 'bit'
3315+
// |-- bit at position 'trailing_zeros'
3316+
// bit_mask: 1 1 ... 1 0 .. 0
3317+
// We do this by xor'ing with the bit_mask
3318+
let index_lo = (bit / bits_per_digit).to_usize().unwrap();
3319+
let index_hi =
3320+
(trailing_zeros / bits_per_digit).to_usize().unwrap();
3321+
let bit_mask_lo = BigDigit::MAX << (bit % bits_per_digit);
3322+
let bit_mask_hi = BigDigit::MAX
3323+
>> (bits_per_digit - 1 - (trailing_zeros % bits_per_digit));
3324+
let digits = self.digits_mut();
3325+
3326+
if index_lo == index_hi {
3327+
digits[index_lo] ^= bit_mask_lo & bit_mask_hi;
3328+
} else {
3329+
digits[index_lo] ^= bit_mask_lo;
3330+
for index in (index_lo + 1)..index_hi {
3331+
digits[index] = BigDigit::MAX;
3332+
}
3333+
digits[index_hi] ^= bit_mask_hi;
3334+
}
3335+
} else {
3336+
// Bit is already cleared
33323337
}
3333-
let twos = negate_carry(*digit, &mut carry_in);
3334-
*digit = negate_carry(twos, &mut carry_out);
33353338
}
3336-
3337-
if carry_out != 0 {
3338-
self.digits_mut().push(1 as BigDigit);
3339+
Equal => {
3340+
if value {
3341+
// Bit is already set
3342+
} else {
3343+
// Clearing the bit at position `trailing_zeros` is the only non-trivial
3344+
// case and is dealt with by doing similarly to what `bitand_neg_pos`
3345+
// does, except we start at digit `bit_index`; all digits below `bit_index`
3346+
// are guaranteed to be zero, so initially we must have
3347+
// `carry_in` = `carry_out` = 1
3348+
let bit_index = (bit / bits_per_digit).to_usize().unwrap();
3349+
let bit_mask = (1 as BigDigit) << (bit % bits_per_digit);
3350+
let mut digit_iter = self.digits_mut().iter_mut().skip(bit_index);
3351+
let mut carry_in = 1;
3352+
let mut carry_out = 1;
3353+
3354+
let digit = digit_iter.next().unwrap();
3355+
let twos_in = negate_carry(*digit, &mut carry_in);
3356+
let twos_out = twos_in & !bit_mask;
3357+
*digit = negate_carry(twos_out, &mut carry_out);
3358+
3359+
for digit in digit_iter {
3360+
if carry_in == 0 && carry_out == 0 {
3361+
// Exit the loop since no more digits can change
3362+
break;
3363+
}
3364+
let twos = negate_carry(*digit, &mut carry_in);
3365+
*digit = negate_carry(twos, &mut carry_out);
3366+
}
3367+
3368+
if carry_out != 0 {
3369+
// All digits have been traversed and there is a carry
3370+
debug_assert_eq!(carry_in, 0);
3371+
self.digits_mut().push(1);
3372+
}
3373+
}
3374+
}
3375+
Greater => {
3376+
self.data.set_bit(bit, !value);
33393377
}
33403378
}
33413379
}
33423380
}
33433381
}
3344-
// the top bit may have been cleared, so normalize
3382+
// The top bit may have been cleared, so normalize
33453383
self.normalize();
33463384
}
33473385
}

tests/bigint.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1328,16 +1328,26 @@ fn test_bit() {
13281328

13291329
#[test]
13301330
fn test_set_bit() {
1331-
let mut x = BigInt::zero();
1331+
let mut x: BigInt;
1332+
1333+
// zero
1334+
x = BigInt::zero();
13321335
x.set_bit(200, true);
13331336
assert_eq!(x, BigInt::one() << 200);
1337+
x = BigInt::zero();
1338+
x.set_bit(200, false);
1339+
assert_eq!(x, BigInt::zero());
1340+
1341+
// positive numbers
1342+
x = BigInt::from_biguint(Plus, BigUint::one() << 200);
13341343
x.set_bit(10, true);
13351344
x.set_bit(200, false);
13361345
assert_eq!(x, BigInt::one() << 10);
13371346
x.set_bit(10, false);
13381347
x.set_bit(5, false);
13391348
assert_eq!(x, BigInt::zero());
13401349

1350+
// negative numbers
13411351
x = BigInt::from(-12i8);
13421352
x.set_bit(200, true);
13431353
assert_eq!(x, BigInt::from(-12i8));
@@ -1359,6 +1369,13 @@ fn test_set_bit() {
13591369
x.set_bit(200, true);
13601370
assert_eq!(x, BigInt::from(-12i8));
13611371

1372+
x = BigInt::from_biguint(Minus, BigUint::one() << 30);
1373+
x.set_bit(10, true);
1374+
assert_eq!(
1375+
x,
1376+
BigInt::from_biguint(Minus, (BigUint::one() << 30) - (BigUint::one() << 10))
1377+
);
1378+
13621379
x = BigInt::from_biguint(Minus, BigUint::one() << 200);
13631380
x.set_bit(40, true);
13641381
assert_eq!(
@@ -1376,4 +1393,8 @@ fn test_set_bit() {
13761393
x = BigInt::from_biguint(Minus, (BigUint::one() << 63) | (BigUint::one() << 62));
13771394
x.set_bit(62, false);
13781395
assert_eq!(x, BigInt::from_biguint(Minus, BigUint::one() << 64));
1396+
1397+
x = BigInt::from_biguint(Minus, (BigUint::one() << 200) - BigUint::one());
1398+
x.set_bit(0, false);
1399+
assert_eq!(x, BigInt::from_biguint(Minus, BigUint::one() << 200));
13791400
}

0 commit comments

Comments
 (0)