Skip to content

Commit 1d9ddc6

Browse files
committed
feat: addition/subtraction acceleration via inline assembly
1 parent a25836e commit 1d9ddc6

File tree

2 files changed

+202
-5
lines changed

2 files changed

+202
-5
lines changed

src/biguint/addition.rs

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
use super::{BigUint, IntDigits};
2+
#[cfg(target_arch = "x86_64")]
3+
use std::arch::asm;
24

35
use crate::big_digit::{self, BigDigit};
46
use crate::UsizePromotion;
@@ -45,6 +47,96 @@ fn adc(carry: u8, lhs: BigDigit, rhs: BigDigit, out: &mut BigDigit) -> u8 {
4547
u8::from(b || d)
4648
}
4749

50+
/// Performs a part of the addition. Returns a tuple containing the carry state
51+
/// and the number of integers that were added
52+
///
53+
/// By using as many registers as possible, we treat digits 5 by 5
54+
#[cfg(target_arch = "x86_64")]
55+
unsafe fn schoolbook_add_assign_x86_64(
56+
lhs: *mut u64,
57+
rhs: *const u64,
58+
mut size: usize,
59+
) -> (bool, usize) {
60+
size /= 5;
61+
if size == 0 {
62+
return (false, 0);
63+
}
64+
65+
let mut c: u8;
66+
let mut idx = 0;
67+
68+
asm!(
69+
// Clear the carry flag
70+
"clc",
71+
72+
"3:",
73+
74+
// Copy a in registers
75+
"mov {a_tmp1}, qword ptr [{a} + 8*{idx}]",
76+
"mov {a_tmp2}, qword ptr [{a} + 8*{idx} + 8]",
77+
"mov {a_tmp3}, qword ptr [{a} + 8*{idx} + 16]",
78+
"mov {a_tmp4}, qword ptr [{a} + 8*{idx} + 24]",
79+
"mov {a_tmp5}, qword ptr [{a} + 8*{idx} + 32]",
80+
81+
// Copy b in registers
82+
"mov {b_tmp1}, qword ptr [{b} + 8*{idx}]",
83+
"mov {b_tmp2}, qword ptr [{b} + 8*{idx} + 8]",
84+
"mov {b_tmp3}, qword ptr [{b} + 8*{idx} + 16]",
85+
"mov {b_tmp4}, qword ptr [{b} + 8*{idx} + 24]",
86+
"mov {b_tmp5}, qword ptr [{b} + 8*{idx} + 32]",
87+
88+
// Perform the addition
89+
"adc {a_tmp1}, {b_tmp1}",
90+
"adc {a_tmp2}, {b_tmp2}",
91+
"adc {a_tmp3}, {b_tmp3}",
92+
"adc {a_tmp4}, {b_tmp4}",
93+
"adc {a_tmp5}, {b_tmp5}",
94+
95+
// Copy the return values
96+
"mov qword ptr [{a} + 8*{idx}], {a_tmp1}",
97+
"mov qword ptr [{a} + 8*{idx} + 8], {a_tmp2}",
98+
"mov qword ptr [{a} + 8*{idx} + 16], {a_tmp3}",
99+
"mov qword ptr [{a} + 8*{idx} + 24], {a_tmp4}",
100+
"mov qword ptr [{a} + 8*{idx} + 32], {a_tmp5}",
101+
102+
// Increment loop counter
103+
// `inc` and `dec` aren't modifying carry flag
104+
"inc {idx}",
105+
"inc {idx}",
106+
"inc {idx}",
107+
"inc {idx}",
108+
"inc {idx}",
109+
"dec {size}",
110+
"jnz 3b",
111+
112+
// Output carry flag and clear
113+
"setc {c}",
114+
"clc",
115+
116+
size = in(reg) size,
117+
a = in(reg) lhs,
118+
b = in(reg) rhs,
119+
c = lateout(reg_byte) c,
120+
idx = inout(reg) idx,
121+
122+
a_tmp1 = out(reg) _,
123+
a_tmp2 = out(reg) _,
124+
a_tmp3 = out(reg) _,
125+
a_tmp4 = out(reg) _,
126+
a_tmp5 = out(reg) _,
127+
128+
b_tmp1 = out(reg) _,
129+
b_tmp2 = out(reg) _,
130+
b_tmp3 = out(reg) _,
131+
b_tmp4 = out(reg) _,
132+
b_tmp5 = out(reg) _,
133+
134+
options(nostack),
135+
);
136+
137+
(c > 0, idx)
138+
}
139+
48140
/// Two argument addition of raw slices, `a += b`, returning the carry.
49141
///
50142
/// This is used when the data `Vec` might need to resize to push a non-zero carry, so we perform
@@ -55,10 +147,17 @@ fn adc(carry: u8, lhs: BigDigit, rhs: BigDigit, out: &mut BigDigit) -> u8 {
55147
pub(super) fn __add2(a: &mut [BigDigit], b: &[BigDigit]) -> BigDigit {
56148
debug_assert!(a.len() >= b.len());
57149

58-
let mut carry = 0;
59150
let (a_lo, a_hi) = a.split_at_mut(b.len());
60151

61-
for (a, b) in a_lo.iter_mut().zip(b) {
152+
// On x86_64 machine, perform most of the addition via inline assembly
153+
#[cfg(target_arch = "x86_64")]
154+
let (c, done) = unsafe { schoolbook_add_assign_x86_64(a_lo.as_mut_ptr(), b.as_ptr(), b.len()) };
155+
#[cfg(not(target_arch = "x86_64"))]
156+
let (c, done) = (false, 0);
157+
158+
let mut carry = c as u8;
159+
160+
for (a, b) in a_lo[done..].iter_mut().zip(b[done..].iter()) {
62161
carry = adc(carry, *a, *b, a);
63162
}
64163

src/biguint/subtraction.rs

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
use super::BigUint;
2+
#[cfg(target_arch = "x86_64")]
3+
use std::arch::asm;
24

35
use crate::big_digit::{self, BigDigit};
46
use crate::UsizePromotion;
@@ -45,14 +47,110 @@ fn sbb(borrow: u8, lhs: BigDigit, rhs: BigDigit, out: &mut BigDigit) -> u8 {
4547
u8::from(b || d)
4648
}
4749

48-
pub(super) fn sub2(a: &mut [BigDigit], b: &[BigDigit]) {
49-
let mut borrow = 0;
50+
/// Performs a part of the subtraction. Returns a tuple containing the carry state
51+
/// and the number of integers that were subtracted
52+
///
53+
/// By using as many registers as possible, we treat digits 5 by 5
54+
#[cfg(target_arch = "x86_64")]
55+
unsafe fn schoolbook_sub_assign_x86_64(
56+
lhs: *mut u64,
57+
rhs: *const u64,
58+
mut size: usize,
59+
) -> (bool, usize) {
60+
size /= 5;
61+
if size == 0 {
62+
return (false, 0);
63+
}
64+
65+
let mut c: u8;
66+
let mut idx = 0;
67+
68+
asm!(
69+
// Clear carry flag
70+
"clc",
71+
72+
"3:",
73+
74+
// Copy a in registers
75+
"mov {a_tmp1}, qword ptr [{a} + 8*{idx}]",
76+
"mov {a_tmp2}, qword ptr [{a} + 8*{idx} + 8]",
77+
"mov {a_tmp3}, qword ptr [{a} + 8*{idx} + 16]",
78+
"mov {a_tmp4}, qword ptr [{a} + 8*{idx} + 24]",
79+
"mov {a_tmp5}, qword ptr [{a} + 8*{idx} + 32]",
80+
81+
// Copy b in registers
82+
"mov {b_tmp1}, qword ptr [{b} + 8*{idx}]",
83+
"mov {b_tmp2}, qword ptr [{b} + 8*{idx} + 8]",
84+
"mov {b_tmp3}, qword ptr [{b} + 8*{idx} + 16]",
85+
"mov {b_tmp4}, qword ptr [{b} + 8*{idx} + 24]",
86+
"mov {b_tmp5}, qword ptr [{b} + 8*{idx} + 32]",
87+
88+
// Perform the subtraction
89+
"sbb {a_tmp1}, {b_tmp1}",
90+
"sbb {a_tmp2}, {b_tmp2}",
91+
"sbb {a_tmp3}, {b_tmp3}",
92+
"sbb {a_tmp4}, {b_tmp4}",
93+
"sbb {a_tmp5}, {b_tmp5}",
94+
95+
// Copy the return values
96+
"mov qword ptr [{a} + 8*{idx}], {a_tmp1}",
97+
"mov qword ptr [{a} + 8*{idx} + 8], {a_tmp2}",
98+
"mov qword ptr [{a} + 8*{idx} + 16], {a_tmp3}",
99+
"mov qword ptr [{a} + 8*{idx} + 24], {a_tmp4}",
100+
"mov qword ptr [{a} + 8*{idx} + 32], {a_tmp5}",
101+
102+
// Increment loop counter
103+
// `inc` and `dec` aren't modifying carry flag
104+
"inc {idx}",
105+
"inc {idx}",
106+
"inc {idx}",
107+
"inc {idx}",
108+
"inc {idx}",
109+
"dec {size}",
110+
"jnz 3b",
111+
112+
// Output carry flag and clear
113+
"setc {c}",
114+
"clc",
115+
116+
size = in(reg) size,
117+
a = in(reg) lhs,
118+
b = in(reg) rhs,
119+
c = lateout(reg_byte) c,
120+
idx = inout(reg) idx,
121+
122+
a_tmp1 = out(reg) _,
123+
a_tmp2 = out(reg) _,
124+
a_tmp3 = out(reg) _,
125+
a_tmp4 = out(reg) _,
126+
a_tmp5 = out(reg) _,
127+
128+
b_tmp1 = out(reg) _,
129+
b_tmp2 = out(reg) _,
130+
b_tmp3 = out(reg) _,
131+
b_tmp4 = out(reg) _,
132+
b_tmp5 = out(reg) _,
133+
134+
options(nostack),
135+
);
136+
137+
(c > 0, idx)
138+
}
50139

140+
pub(super) fn sub2(a: &mut [BigDigit], b: &[BigDigit]) {
51141
let len = Ord::min(a.len(), b.len());
52142
let (a_lo, a_hi) = a.split_at_mut(len);
53143
let (b_lo, b_hi) = b.split_at(len);
54144

55-
for (a, b) in a_lo.iter_mut().zip(b_lo) {
145+
// On x86_64 machine, perform most of the subtraction via inline assembly
146+
#[cfg(target_arch = "x86_64")]
147+
let (b, done) = unsafe { schoolbook_sub_assign_x86_64(a_lo.as_mut_ptr(), b_lo.as_ptr(), len) };
148+
#[cfg(not(target_arch = "x86_64"))]
149+
let (b, done) = (false, 0);
150+
151+
let mut borrow = b as u8;
152+
153+
for (a, b) in a_lo[done..].iter_mut().zip(b_lo[done..].iter()) {
56154
borrow = sbb(borrow, *a, *b, a);
57155
}
58156

0 commit comments

Comments
 (0)