@@ -281,40 +281,87 @@ impl<T: Float> Complex<T> {
281
281
///
282
282
/// The branch satisfies `-π/2 ≤ arg(sqrt(z)) ≤ π/2`.
283
283
#[ inline]
284
- pub fn sqrt ( self ) -> Self {
285
- if self . im . is_zero ( ) {
286
- if self . re . is_sign_positive ( ) {
287
- // simple positive real √r, and copy `im` for its sign
288
- Self :: new ( self . re . sqrt ( ) , self . im )
284
+ pub fn sqrt ( mut self ) -> Self {
285
+ // complex sqrt algorithm based on the algorithm from
286
+ // dl.acm.org/doi/abs/10.1145/363717.363780 with additional tweaks
287
+ // to increase accuracy. Compared to a naive implementation that
288
+ // reuses the complex exp/ln implementations this algorithm has better
289
+ // accuracy since both (real) sqrt and (real) hypot are guaranteed to
290
+ // round perfectly. It's also faster since this implementation requires
291
+ // fewer transcendental functions and those it does use (sqrt/hypot) are
292
+ // faster compared to exp/sin/cos.
293
+ //
294
+ // The musl libc implementation was referenced while implementing the
295
+ // algorithm here:
296
+ // https://git.musl-libc.org/cgit/musl/tree/src/complex/csqrt.c
297
+
298
+ // TODO: rounding for very tiny subnormal numbers isn't perfect yet, so
299
+ // the assert shown below fails. In the very worst case this leads to about
300
+ // 10% accuracy loss (see example below). As the magnitude increases the
301
+ // error quickly drops to basically zero.
302
+ //
303
+ // glibc handles that (but other implementations like musl and numpy do
304
+ // not) by upscaling very small values. That upscaling (and particularly
305
+ // its reversal) are weird and hard to understand (and rely on mantissa
306
+ // bit size which we can't get out of the trait). In general the glibc
307
+ // implementation is ever so subtly different and I wouldn't want to
308
+ // introduce bugs by trying to adapt the underflow handling.
309
+ //
310
+ // assert_eq!(
311
+ // Complex64::new(5.212e-324, 5.212e-324).sqrt(),
312
+ // Complex64::new(2.4421097261308304e-162, 1.0115549693666347e-162)
313
+ // );
314
+
315
+ // special cases for correct nan/inf handling; the order of the checks matters
316
+ // see https://en.cppreference.com/w/c/numeric/complex/csqrt
317
+
318
+ if self . re . is_zero ( ) && self . im . is_zero ( ) {
319
+ // √(±0 ± 0 i) = +0 ± 0 i: the real part is always +0, `im` is passed through with its sign
320
+ return Self :: new ( T :: zero ( ) , self . im ) ;
321
+ }
322
+ if self . im . is_infinite ( ) {
323
+ // √(x ± inf i) = inf ± inf i for every x, including NaN (this check runs before the NaN check)
324
+ return Self :: new ( T :: infinity ( ) , self . im ) ;
325
+ }
326
+ if self . re . is_nan ( ) {
327
+ // NaN real part with a non-infinite `im`: the result is nan + nan i
328
+ return Self :: new ( self . re , T :: nan ( ) ) ;
329
+ }
330
+ if self . re . is_infinite ( ) {
331
+ // √(inf +/- NaN i) = inf +/- NaN i
332
+ // √(inf +/- x i) = inf +/- 0 i
333
+ // √(-inf +/- NaN i) = NaN +/- inf i
334
+ // √(-inf +/- x i) = 0 +/- inf i
335
+
336
+ // if im is nan this is nan, otherwise it's zero (an infinite im already returned above)
337
+ #[ allow( clippy:: eq_op) ]
338
+ let zero_or_nan = self . im - self . im ;
339
+ if self . re . is_sign_negative ( ) {
340
+ return Self :: new ( zero_or_nan. abs ( ) , self . re . copysign ( self . im ) ) ;
289
341
} else {
290
- // √(r e^(iπ)) = √r e^(iπ/2) = i√r
291
- // √(r e^(-iπ)) = √r e^(-iπ/2) = -i√r
292
- let re = T :: zero ( ) ;
293
- let im = ( -self . re ) . sqrt ( ) ;
294
- if self . im . is_sign_positive ( ) {
295
- Self :: new ( re, im)
296
- } else {
297
- Self :: new ( re, -im)
298
- }
299
- }
300
- } else if self . re . is_zero ( ) {
301
- // √(r e^(iπ/2)) = √r e^(iπ/4) = √(r/2) + i√(r/2)
302
- // √(r e^(-iπ/2)) = √r e^(-iπ/4) = √(r/2) - i√(r/2)
303
- let one = T :: one ( ) ;
304
- let two = one + one;
305
- let x = ( self . im . abs ( ) / two) . sqrt ( ) ;
306
- if self . im . is_sign_positive ( ) {
307
- Self :: new ( x, x)
308
- } else {
309
- Self :: new ( x, -x)
342
+ return Self :: new ( self . re , zero_or_nan. copysign ( self . im ) ) ;
310
343
}
344
+ }
345
+ let two = T :: one ( ) + T :: one ( ) ;
346
+ let four = two + two;
347
+ let overflow = T :: max_value ( ) / ( T :: one ( ) + T :: sqrt ( two) ) ; // threshold max / (1 + √2); presumably chosen so |re| + norm() stays finite below it (assumes norm() is hypot(re, im) — TODO confirm)
348
+ let max_magnitude = self . re . abs ( ) . max ( self . im . abs ( ) ) ;
349
+ let scale = max_magnitude >= overflow;
350
+ if scale {
351
+ self = self / four; // pre-scale large inputs by 1/4; undone by the `* two` at the end
352
+ }
353
+ if self . re . is_sign_negative ( ) {
354
+ let tmp = ( ( -self . re + self . norm ( ) ) / two) . sqrt ( ) ; // -re is non-negative here; tmp is the magnitude of the result's imaginary part
355
+ self . re = self . im . abs ( ) / ( two * tmp) ; // from 2·re(√z)·im(√z) = im(z)
356
+ self . im = tmp. copysign ( self . im ) ; // imaginary part takes the sign of the input's imaginary part
311
357
} else {
312
- // formula: sqrt(r e^(it )) = sqrt(r) e^(it/2)
313
- let one = T :: one ( ) ;
314
- let two = one + one ;
315
- let ( r , theta ) = self . to_polar ( ) ;
316
- Self :: from_polar ( r . sqrt ( ) , theta / two)
358
+ self . re = ( ( self . re + self . norm ( ) ) / two ) . sqrt ( ) ; // real part of the root, computed first
359
+ self . im = self . im / ( two * self . re ) ; // from 2·re(√z)·im(√z) = im(z), using the new real part
360
+ }
361
+ if scale {
362
+ self = self * two; // √(z/4) = √z / 2, so doubling undoes the pre-scaling
317
363
}
364
+ self
318
365
}
319
366
320
367
/// Computes the principal value of the cube root of `self`.
@@ -2065,6 +2112,50 @@ pub(crate) mod test {
2065
2112
}
2066
2113
}
2067
2114
2115
+ #[ test]
2116
+ fn test_sqrt_nan ( ) { // exercises sqrt on NaN, infinite and signed-zero inputs
2117
+ assert ! ( close_naninf( // √(inf + NaN i) = inf + NaN i
2118
+ Complex64 :: new( f64 :: INFINITY , f64 :: NAN ) . sqrt( ) ,
2119
+ Complex64 :: new( f64 :: INFINITY , f64 :: NAN ) ,
2120
+ ) ) ;
2121
+ assert ! ( close_naninf( // √(NaN + inf i) = inf + inf i
2122
+ Complex64 :: new( f64 :: NAN , f64 :: INFINITY ) . sqrt( ) ,
2123
+ Complex64 :: new( f64 :: INFINITY , f64 :: INFINITY ) ,
2124
+ ) ) ;
2125
+ assert ! ( close_naninf( // √(-inf - NaN i) = NaN - inf i: the sign of the NaN selects the sign of the infinity
2126
+ Complex64 :: new( f64 :: NEG_INFINITY , -f64 :: NAN ) . sqrt( ) ,
2127
+ Complex64 :: new( f64 :: NAN , f64 :: NEG_INFINITY ) ,
2128
+ ) ) ;
2129
+ assert ! ( close_naninf( // √(-inf + NaN i) = NaN + inf i
2130
+ Complex64 :: new( f64 :: NEG_INFINITY , f64 :: NAN ) . sqrt( ) ,
2131
+ Complex64 :: new( f64 :: NAN , f64 :: INFINITY ) ,
2132
+ ) ) ;
2133
+ assert ! ( close_naninf( // √(-0 + 0 i) = 0 + 0 i
2134
+ Complex64 :: new( -0.0 , 0.0 ) . sqrt( ) ,
2135
+ Complex64 :: new( 0.0 , 0.0 ) ,
2136
+ ) ) ;
2137
+ for x in ( -100 ..100 ) . map ( f64:: from) { // integer-valued finite x from -100 to 99
2138
+ assert ! ( close_naninf( // √(x + inf i) = inf + inf i
2139
+ Complex64 :: new( x, f64 :: INFINITY ) . sqrt( ) ,
2140
+ Complex64 :: new( f64 :: INFINITY , f64 :: INFINITY ) ,
2141
+ ) ) ;
2142
+ assert ! ( close_naninf( // √(NaN + x i) = NaN + NaN i
2143
+ Complex64 :: new( f64 :: NAN , x) . sqrt( ) ,
2144
+ Complex64 :: new( f64 :: NAN , f64 :: NAN ) ,
2145
+ ) ) ;
2146
+ // √(inf + x i) = inf + 0 i, where the zero carries the sign of x
2147
+ assert ! ( close_naninf(
2148
+ Complex64 :: new( f64 :: INFINITY , x) . sqrt( ) ,
2149
+ Complex64 :: new( f64 :: INFINITY , 0.0 . copysign( x) ) ,
2150
+ ) ) ;
2151
+ // √(-inf + x i) = 0 + inf i, where the infinity carries the sign of x
2152
+ assert ! ( close_naninf(
2153
+ Complex64 :: new( f64 :: NEG_INFINITY , x) . sqrt( ) ,
2154
+ Complex64 :: new( 0.0 , f64 :: INFINITY . copysign( x) ) ,
2155
+ ) ) ;
2156
+ }
2157
+ }
2158
+
2068
2159
#[ test]
2069
2160
fn test_cbrt ( ) {
2070
2161
assert ! ( close( _0_0i. cbrt( ) , _0_0i) ) ;
0 commit comments