Skip to content

Commit 96db82d

Browse files
Merge #104
104: Adapted complex exp() function so that it can handle inf and nan arguments as well r=cuviper a=JorisDeRidder ### Why this PR? The current version of the complex exp() function is not able to handle arguments that contain +/- inf or NaN in their real or imaginary part. This impacts other complex functions that use the exp() function. For example, for the most widely used implementation of the complex Faddeeva function `w()`, the current `exp()` implementation leads to `w(1e160 - 1e159*i) = NaN + NaN *i` , while the correct value is `-5.586035480670854e-162 + 5.5860354806708545e-161 * i`. The underlying reason is that the current `exp()` implementation erroneously returns `exp(-inf + inf *i) = NaN + Nan *i` instead of the correct `0 + 0*i`. Cf also issue #103. ### Contents of this PR - I propose a modified complex exp() function that does deal with inf and nan. The added logic was strongly inspired by the one implemented in the `<complex>` C++ header that comes with clang++. - I added extra unit tests. The relevant values were taken from [this page](https://en.cppreference.com/w/cpp/numeric/complex/exp). - For the unit tests I also implemented `close_naninf()` and `close_naninf_to_tol()` as the existing functions `close()` and `close_to_tol()` are not able to deal with inf and nan. Co-authored-by: Joris De Ridder <joris.deridder@kuleuven.be> Co-authored-by: Joris De Ridder <5747893+JorisDeRidder@users.noreply.github.com>
2 parents 8cd50f8 + f294b51 commit 96db82d

File tree

1 file changed

+121
-10
lines changed

1 file changed

+121
-10
lines changed

src/lib.rs

Lines changed: 121 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,28 @@ impl<T: Float> Complex<T> {
198198
/// Computes `e^(self)`, where `e` is the base of the natural logarithm.
199199
#[inline]
200200
pub fn exp(self) -> Self {
201-
// formula: e^(a + bi) = e^a (cos(b) + i*sin(b))
202-
// = from_polar(e^a, b)
203-
Self::from_polar(self.re.exp(), self.im)
201+
// formula: e^(a + bi) = e^a (cos(b) + i*sin(b)) = from_polar(e^a, b)
202+
203+
let Complex { re, mut im } = self;
204+
// Treat the corner cases +∞, -∞, and NaN
205+
if re.is_infinite() {
206+
if re < T::zero() {
207+
if !im.is_finite() {
208+
return Self::new(T::zero(), T::zero());
209+
}
210+
} else {
211+
if im == T::zero() || !im.is_finite() {
212+
if im.is_infinite() {
213+
im = T::nan();
214+
}
215+
return Self::new(re, im);
216+
}
217+
}
218+
} else if re.is_nan() && im == T::zero() {
219+
return self;
220+
}
221+
222+
Self::from_polar(re.exp(), im)
204223
}
205224

206225
/// Computes the principal value of natural logarithm of `self`.
@@ -1578,14 +1597,31 @@ pub(crate) mod test {
15781597

15791598
use num_traits::{Num, One, Zero};
15801599

1581-
pub const _0_0i: Complex64 = Complex { re: 0.0, im: 0.0 };
1582-
pub const _1_0i: Complex64 = Complex { re: 1.0, im: 0.0 };
1583-
pub const _1_1i: Complex64 = Complex { re: 1.0, im: 1.0 };
1584-
pub const _0_1i: Complex64 = Complex { re: 0.0, im: 1.0 };
1585-
pub const _neg1_1i: Complex64 = Complex { re: -1.0, im: 1.0 };
1586-
pub const _05_05i: Complex64 = Complex { re: 0.5, im: 0.5 };
1600+
pub const _0_0i: Complex64 = Complex::new(0.0, 0.0);
1601+
pub const _1_0i: Complex64 = Complex::new(1.0, 0.0);
1602+
pub const _1_1i: Complex64 = Complex::new(1.0, 1.0);
1603+
pub const _0_1i: Complex64 = Complex::new(0.0, 1.0);
1604+
pub const _neg1_1i: Complex64 = Complex::new(-1.0, 1.0);
1605+
pub const _05_05i: Complex64 = Complex::new(0.5, 0.5);
15871606
pub const all_consts: [Complex64; 5] = [_0_0i, _1_0i, _1_1i, _neg1_1i, _05_05i];
1588-
pub const _4_2i: Complex64 = Complex { re: 4.0, im: 2.0 };
1607+
pub const _4_2i: Complex64 = Complex::new(4.0, 2.0);
1608+
pub const _1_infi: Complex64 = Complex::new(1.0, f64::INFINITY);
1609+
pub const _neg1_infi: Complex64 = Complex::new(-1.0, f64::INFINITY);
1610+
pub const _1_nani: Complex64 = Complex::new(1.0, f64::NAN);
1611+
pub const _neg1_nani: Complex64 = Complex::new(-1.0, f64::NAN);
1612+
pub const _inf_0i: Complex64 = Complex::new(f64::INFINITY, 0.0);
1613+
pub const _neginf_1i: Complex64 = Complex::new(f64::NEG_INFINITY, 1.0);
1614+
pub const _neginf_neg1i: Complex64 = Complex::new(f64::NEG_INFINITY, -1.0);
1615+
pub const _inf_1i: Complex64 = Complex::new(f64::INFINITY, 1.0);
1616+
pub const _inf_neg1i: Complex64 = Complex::new(f64::INFINITY, -1.0);
1617+
pub const _neginf_infi: Complex64 = Complex::new(f64::NEG_INFINITY, f64::INFINITY);
1618+
pub const _inf_infi: Complex64 = Complex::new(f64::INFINITY, f64::INFINITY);
1619+
pub const _neginf_nani: Complex64 = Complex::new(f64::NEG_INFINITY, f64::NAN);
1620+
pub const _inf_nani: Complex64 = Complex::new(f64::INFINITY, f64::NAN);
1621+
pub const _nan_0i: Complex64 = Complex::new(f64::NAN, 0.0);
1622+
pub const _nan_1i: Complex64 = Complex::new(f64::NAN, 1.0);
1623+
pub const _nan_neg1i: Complex64 = Complex::new(f64::NAN, -1.0);
1624+
pub const _nan_nani: Complex64 = Complex::new(f64::NAN, f64::NAN);
15891625

15901626
#[test]
15911627
fn test_consts() {
@@ -1736,6 +1772,56 @@ pub(crate) mod test {
17361772
close
17371773
}
17381774

1775+
// Version that also works if re or im are +inf, -inf, or nan
1776+
fn close_naninf(a: Complex64, b: Complex64) -> bool {
1777+
close_naninf_to_tol(a, b, 1.0e-10)
1778+
}
1779+
1780+
fn close_naninf_to_tol(a: Complex64, b: Complex64, tol: f64) -> bool {
1781+
let mut close = true;
1782+
1783+
// Compare the real parts
1784+
if a.re.is_finite() {
1785+
if b.re.is_finite() {
1786+
close = (a.re == b.re) || (a.re - b.re).abs() < tol;
1787+
} else {
1788+
close = false;
1789+
}
1790+
} else if (a.re.is_nan() && !b.re.is_nan())
1791+
|| (a.re.is_infinite()
1792+
&& a.re.is_sign_positive()
1793+
&& !(b.re.is_infinite() && b.re.is_sign_positive()))
1794+
|| (a.re.is_infinite()
1795+
&& a.re.is_sign_negative()
1796+
&& !(b.re.is_infinite() && b.re.is_sign_negative()))
1797+
{
1798+
close = false;
1799+
}
1800+
1801+
// Compare the imaginary parts
1802+
if a.im.is_finite() {
1803+
if b.im.is_finite() {
1804+
close &= (a.im == b.im) || (a.im - b.im).abs() < tol;
1805+
} else {
1806+
close = false;
1807+
}
1808+
} else if (a.im.is_nan() && !b.im.is_nan())
1809+
|| (a.im.is_infinite()
1810+
&& a.im.is_sign_positive()
1811+
&& !(b.im.is_infinite() && b.im.is_sign_positive()))
1812+
|| (a.im.is_infinite()
1813+
&& a.im.is_sign_negative()
1814+
&& !(b.im.is_infinite() && b.im.is_sign_negative()))
1815+
{
1816+
close = false;
1817+
}
1818+
1819+
if close == false {
1820+
println!("{:?} != {:?}", a, b);
1821+
}
1822+
close
1823+
}
1824+
17391825
#[test]
17401826
fn test_exp2() {
17411827
assert!(close(_0_0i.exp2(), _1_0i));
@@ -1760,6 +1846,31 @@ pub(crate) mod test {
17601846
(c + _0_1i.scale(f64::consts::PI * 2.0)).exp()
17611847
));
17621848
}
1849+
1850+
// The test values below were taken from https://en.cppreference.com/w/cpp/numeric/complex/exp
1851+
assert!(close_naninf(_1_infi.exp(), _nan_nani));
1852+
assert!(close_naninf(_neg1_infi.exp(), _nan_nani));
1853+
assert!(close_naninf(_1_nani.exp(), _nan_nani));
1854+
assert!(close_naninf(_neg1_nani.exp(), _nan_nani));
1855+
assert!(close_naninf(_inf_0i.exp(), _inf_0i));
1856+
assert!(close_naninf(_neginf_1i.exp(), 0.0 * Complex::cis(1.0)));
1857+
assert!(close_naninf(_neginf_neg1i.exp(), 0.0 * Complex::cis(-1.0)));
1858+
assert!(close_naninf(
1859+
_inf_1i.exp(),
1860+
f64::INFINITY * Complex::cis(1.0)
1861+
));
1862+
assert!(close_naninf(
1863+
_inf_neg1i.exp(),
1864+
f64::INFINITY * Complex::cis(-1.0)
1865+
));
1866+
assert!(close_naninf(_neginf_infi.exp(), _0_0i)); // Note: ±0±0i: signs of zeros are unspecified
1867+
assert!(close_naninf(_inf_infi.exp(), _inf_nani)); // Note: ±∞+NaN*i: sign of the real part is unspecified
1868+
assert!(close_naninf(_neginf_nani.exp(), _0_0i)); // Note: ±0±0i: signs of zeros are unspecified
1869+
assert!(close_naninf(_inf_nani.exp(), _inf_nani)); // Note: ±∞+NaN*i: sign of the real part is unspecified
1870+
assert!(close_naninf(_nan_0i.exp(), _nan_0i));
1871+
assert!(close_naninf(_nan_1i.exp(), _nan_nani));
1872+
assert!(close_naninf(_nan_neg1i.exp(), _nan_nani));
1873+
assert!(close_naninf(_nan_nani.exp(), _nan_nani));
17631874
}
17641875

17651876
#[test]

0 commit comments

Comments
 (0)