Skip to content

Commit ed22935

Browse files
committed
more accurate sqrt function
1 parent 91fdc06 commit ed22935

File tree

1 file changed

+121
-30
lines changed

1 file changed

+121
-30
lines changed

src/lib.rs

Lines changed: 121 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -281,40 +281,87 @@ impl<T: Float> Complex<T> {
281281
///
282282
/// The branch satisfies `-π/2 ≤ arg(sqrt(z)) ≤ π/2`.
283283
#[inline]
284-
pub fn sqrt(self) -> Self {
285-
if self.im.is_zero() {
286-
if self.re.is_sign_positive() {
287-
// simple positive real √r, and copy `im` for its sign
288-
Self::new(self.re.sqrt(), self.im)
284+
pub fn sqrt(mut self) -> Self {
285+
// complex sqrt algorithm based on the algorithm from
286+
// dl.acm.org/doi/abs/10.1145/363717.363780 with additional tweaks
287+
// to increase accuracy. Compared to a naive implementation that
288+
// reuses the complex exp/ln implementations this algorithm has better
289+
// accuracy since both (real) sqrt and (real) hypot are guaranteed to
290+
// round perfectly. It's also faster since this implementation requires
291+
// fewer transcendental functions and those it does use (sqrt/hypot) are
292+
// faster compared to exp/sin/cos.
293+
//
294+
// The musl libc implementation was referenced while implementing the
295+
// algorithm here:
296+
// https://git.musl-libc.org/cgit/musl/tree/src/complex/csqrt.c
297+
298+
// TODO: rounding for very tiny subnormal numbers isn't perfect yet so
299+
// the assert shown fails in the very worst case this leads to about
300+
// 10% accuracy loss (see example below). As the magnitude increase the
301+
// error quickly drops to basically zero.
302+
//
303+
// glibc handles that (but other implementations like musl and numpy do
304+
// not) by upscaling very small values. That upscaling (and particularly
305+
// it's reversal) are weird and hard to understand (and rely on mantissa
306+
// bit size which we can't get out of the trait). In general the glibc
307+
// implementation is ever so subtley different and I wouldn't want to
308+
// introduce bugs by trying to adapt the underflow handling.
309+
//
310+
// assert_eq!(
311+
// Complex64::new(5.212e-324, 5.212e-324).sqrt(),
312+
// Complex64::new(2.4421097261308304e-162, 1.0115549693666347e-162)
313+
// );
314+
315+
// special cases for correct nan/inf handling
316+
// see https://en.cppreference.com/w/c/numeric/complex/csqrt
317+
318+
if self.re.is_zero() && self.im.is_zero() {
319+
// 0 +/- 0 i
320+
return Self::new(T::zero(), self.im);
321+
}
322+
if self.im.is_infinite() {
323+
// √(x +/- inf i) = inf +/- inf i for any x (including NaN)
324+
return Self::new(T::infinity(), self.im);
325+
}
326+
if self.re.is_nan() {
327+
// √(NaN + x i) = NaN + NaN i (infinite x was handled above)
328+
return Self::new(self.re, T::nan());
329+
}
330+
if self.re.is_infinite() {
331+
// √(inf +/- NaN i) = inf +/- NaN i
332+
// √(inf +/- x i) = inf +/- 0 i
333+
// √(-inf +/- NaN i) = NaN +/- inf i
334+
// √(-inf +/- x i) = 0 +/- inf i
335+
336+
// if im is inf (or nan) this is nan, otherwise it's zero
337+
#[allow(clippy::eq_op)]
338+
let zero_or_nan = self.im - self.im;
339+
if self.re.is_sign_negative() {
340+
return Self::new(zero_or_nan.abs(), self.re.copysign(self.im));
289341
} else {
290-
// √(r e^(iπ)) = √r e^(iπ/2) = i√r
291-
// √(r e^(-iπ)) = √r e^(-iπ/2) = -i√r
292-
let re = T::zero();
293-
let im = (-self.re).sqrt();
294-
if self.im.is_sign_positive() {
295-
Self::new(re, im)
296-
} else {
297-
Self::new(re, -im)
298-
}
299-
}
300-
} else if self.re.is_zero() {
301-
// √(r e^(iπ/2)) = √r e^(iπ/4) = √(r/2) + i√(r/2)
302-
// √(r e^(-iπ/2)) = √r e^(-iπ/4) = √(r/2) - i√(r/2)
303-
let one = T::one();
304-
let two = one + one;
305-
let x = (self.im.abs() / two).sqrt();
306-
if self.im.is_sign_positive() {
307-
Self::new(x, x)
308-
} else {
309-
Self::new(x, -x)
342+
return Self::new(self.re, zero_or_nan.copysign(self.im));
310343
}
344+
}
345+
let two = T::one() + T::one();
346+
let four = two + two;
347+
let overflow = T::max_value() / (T::one() + T::sqrt(two));
348+
let max_magnitude = self.re.abs().max(self.im.abs());
349+
let scale = max_magnitude >= overflow;
350+
if scale {
351+
self = self / four;
352+
}
353+
if self.re.is_sign_negative() {
354+
let tmp = ((-self.re + self.norm()) / two).sqrt();
355+
self.re = self.im.abs() / (two * tmp);
356+
self.im = tmp.copysign(self.im);
311357
} else {
312-
// formula: sqrt(r e^(it)) = sqrt(r) e^(it/2)
313-
let one = T::one();
314-
let two = one + one;
315-
let (r, theta) = self.to_polar();
316-
Self::from_polar(r.sqrt(), theta / two)
358+
self.re = ((self.re + self.norm()) / two).sqrt();
359+
self.im = self.im / (two * self.re);
360+
}
361+
if scale {
362+
self = self * two;
317363
}
364+
self
318365
}
319366

320367
/// Computes the principal value of the cube root of `self`.
@@ -2065,6 +2112,50 @@ pub(crate) mod test {
20652112
}
20662113
}
20672114

2115+
#[test]
2116+
fn test_sqrt_nan() {
2117+
assert!(close_naninf(
2118+
Complex64::new(f64::INFINITY, f64::NAN).sqrt(),
2119+
Complex64::new(f64::INFINITY, f64::NAN),
2120+
));
2121+
assert!(close_naninf(
2122+
Complex64::new(f64::NAN, f64::INFINITY).sqrt(),
2123+
Complex64::new(f64::INFINITY, f64::INFINITY),
2124+
));
2125+
assert!(close_naninf(
2126+
Complex64::new(f64::NEG_INFINITY, -f64::NAN).sqrt(),
2127+
Complex64::new(f64::NAN, f64::NEG_INFINITY),
2128+
));
2129+
assert!(close_naninf(
2130+
Complex64::new(f64::NEG_INFINITY, f64::NAN).sqrt(),
2131+
Complex64::new(f64::NAN, f64::INFINITY),
2132+
));
2133+
assert!(close_naninf(
2134+
Complex64::new(-0.0, 0.0).sqrt(),
2135+
Complex64::new(0.0, 0.0),
2136+
));
2137+
for x in (-100..100).map(f64::from) {
2138+
assert!(close_naninf(
2139+
Complex64::new(x, f64::INFINITY).sqrt(),
2140+
Complex64::new(f64::INFINITY, f64::INFINITY),
2141+
));
2142+
assert!(close_naninf(
2143+
Complex64::new(f64::NAN, x).sqrt(),
2144+
Complex64::new(f64::NAN, f64::NAN),
2145+
));
2146+
// √(inf + x i) = inf + 0 i
2147+
assert!(close_naninf(
2148+
Complex64::new(f64::INFINITY, x).sqrt(),
2149+
Complex64::new(f64::INFINITY, 0.0.copysign(x)),
2150+
));
2151+
// √(-inf + x i) = 0 + inf i
2152+
assert!(close_naninf(
2153+
Complex64::new(f64::NEG_INFINITY, x).sqrt(),
2154+
Complex64::new(0.0, f64::INFINITY.copysign(x)),
2155+
));
2156+
}
2157+
}
2158+
20682159
#[test]
20692160
fn test_cbrt() {
20702161
assert!(close(_0_0i.cbrt(), _0_0i));

0 commit comments

Comments
 (0)