Skip to content

Commit 40ecdf6

Browse files
committed
Addressed feedback, notable changes include adding trait bounds and sealing ComplexFloat.
1 parent d4ce0bd commit 40ecdf6

File tree

2 files changed

+50
-16
lines changed

2 files changed

+50
-16
lines changed

src/complex_float.rs

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,30 @@
11
// Keeps us from accidentally creating a recursive impl rather than a real one.
22
#![deny(unconditional_recursion)]
3+
#![cfg(any(feature = "std", feature = "libm"))]
34

4-
use num_traits::{float::FloatCore, Float, FloatConst};
5+
use core::ops::Neg;
6+
7+
use num_traits::{float::FloatCore, Float, FloatConst, Num, NumCast, Signed};
58

69
use crate::Complex;
710

11+
mod private {
12+
use num_traits::{float::FloatCore, Float, FloatConst, Signed};
13+
14+
use crate::Complex;
15+
16+
pub trait Seal {}
17+
18+
impl<T> Seal for T where T: Float + FloatConst {}
19+
impl<T: Float + FloatCore + FloatConst + Signed> Seal for Complex<T> {}
20+
}
21+
822
/// Generic trait for floating point complex numbers
923
/// This trait defines methods which are common to complex floating point numbers and regular floating point numbers.
10-
#[cfg(any(feature = "std", feature = "libm"))]
11-
pub trait ComplexFloat {
12-
type Real;
24+
/// This trait is sealed to prevent it from being implemented by anything other than floating point scalars and [Complex] floats.
25+
pub trait ComplexFloat: Num + NumCast + Copy + Neg<Output = Self> + private::Seal {
26+
/// The type used to represent the real coefficients of this complex number.
27+
type Real: Float + FloatConst;
1328

1429
/// Returns `true` if this value is `NaN` and false otherwise.
1530
fn is_nan(self) -> bool;
@@ -22,11 +37,10 @@ pub trait ComplexFloat {
2237
fn is_finite(self) -> bool;
2338

2439
/// Returns `true` if the number is neither zero, infinite,
25-
/// [subnormal][subnormal], or `NaN`.
26-
/// [subnormal]: http://en.wikipedia.org/wiki/Denormal_number
40+
/// [subnormal](http://en.wikipedia.org/wiki/Denormal_number), or `NaN`.
2741
fn is_normal(self) -> bool;
2842

29-
/// Take the reciprocal (inverse) of a number, `1/x`.
43+
/// Take the reciprocal (inverse) of a number, `1/x`. See also [Complex::finv].
3044
fn recip(self) -> Self;
3145

3246
/// Raises `self` to a signed integer power.
@@ -47,6 +61,9 @@ pub trait ComplexFloat {
4761
/// Returns `2^(self)`.
4862
fn exp2(self) -> Self;
4963

64+
/// Returns `base^(self)`.
65+
fn expf(self, base: Self::Real) -> Self;
66+
5067
/// Returns the natural logarithm of the number.
5168
fn ln(self) -> Self;
5269

@@ -106,16 +123,21 @@ pub trait ComplexFloat {
106123
/// Returns the real part of the number.
107124
fn re(self) -> Self::Real;
108125

109-
/// Returns the imaginary part of the number which equals to zero.
126+
/// Returns the imaginary part of the number.
110127
fn im(self) -> Self::Real;
111128

112-
/// Returns the absolute value of the number.
129+
/// Returns the absolute value of the number. See also [Complex::norm]
113130
fn abs(self) -> Self::Real;
114131

132+
/// Returns the L1 norm `|re| + |im|` -- the [Manhattan distance] from the origin.
133+
///
134+
/// [Manhattan distance]: https://en.wikipedia.org/wiki/Taxicab_geometry
135+
fn l1_norm(&self) -> Self::Real;
136+
115137
/// Computes the argument of the number.
116138
fn arg(self) -> Self::Real;
117139

118-
/// Comutes the complex conjugate of `self`.
140+
/// Computes the complex conjugate of the number.
119141
///
120142
/// Formula: `a+bi -> a-bi`
121143
fn conj(self) -> Self;
@@ -141,7 +163,6 @@ macro_rules! forward_ref {
141163
)*};
142164
}
143165

144-
#[cfg(any(feature = "std", feature = "libm"))]
145166
impl<T> ComplexFloat for T
146167
where
147168
T: Float + FloatConst,
@@ -156,7 +177,7 @@ where
156177
T::zero()
157178
}
158179

159-
fn abs(self) -> Self::Real {
180+
fn l1_norm(&self) -> Self::Real {
160181
self.abs()
161182
}
162183

@@ -178,6 +199,10 @@ where
178199
self
179200
}
180201

202+
fn expf(self, base: Self::Real) -> Self {
203+
base.powf(self)
204+
}
205+
181206
forward! {
182207
Float::is_normal(self) -> bool;
183208
Float::is_infinite(self) -> bool;
@@ -206,11 +231,11 @@ where
206231
Float::asinh(self) -> Self;
207232
Float::acosh(self) -> Self;
208233
Float::atanh(self) -> Self;
234+
Float::abs(self) -> Self;
209235
}
210236
}
211237

212-
#[cfg(any(feature = "std", feature = "libm"))]
213-
impl<T: Float + FloatCore + FloatConst> ComplexFloat for Complex<T> {
238+
impl<T: Float + FloatCore + FloatConst + Signed> ComplexFloat for Complex<T> {
214239
type Real = T;
215240

216241
fn re(self) -> Self::Real {
@@ -229,6 +254,10 @@ impl<T: Float + FloatCore + FloatConst> ComplexFloat for Complex<T> {
229254
self.finv()
230255
}
231256

257+
fn l1_norm(&self) -> Self::Real {
258+
Complex::l1_norm(self)
259+
}
260+
232261
forward! {
233262
Complex::arg(self) -> Self::Real;
234263
Complex::powc(self, exp: Complex<Self::Real>) -> Complex<Self::Real>;
@@ -244,6 +273,7 @@ impl<T: Float + FloatCore + FloatConst> ComplexFloat for Complex<T> {
244273
Complex::sqrt(self) -> Self;
245274
Complex::cbrt(self) -> Self;
246275
Complex::exp(self) -> Self;
276+
Complex::expf(self, base: Self::Real) -> Self;
247277
Complex::ln(self) -> Self;
248278
Complex::sin(self) -> Self;
249279
Complex::cos(self) -> Self;

src/lib.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,14 @@ use num_traits::{FloatConst, Inv, MulAdd, Num, One, Pow, Signed, Zero};
3636
use num_traits::float::Float;
3737
use num_traits::float::FloatCore;
3838

39+
#[cfg(any(feature = "std", feature = "libm"))]
40+
mod complex_float;
41+
3942
mod cast;
4043
mod pow;
4144

42-
pub mod complex_float;
45+
#[cfg(any(feature = "std", feature = "libm"))]
46+
pub use complex_float::ComplexFloat;
4347

4448
#[cfg(feature = "rand")]
4549
mod crand;
@@ -574,7 +578,7 @@ impl<T: Float + FloatConst> Complex<T> {
574578
#[inline]
575579
pub fn exp2(self) -> Self {
576580
// formula: 2^(a + bi) = 2^a (cos(b*log2) + i*sin(b*log2))
577-
// = from_polar(e^a, b)
581+
// = from_polar(2^a, b*log2)
578582
Self::from_polar(self.re.exp2(), self.im * T::LN_2())
579583
}
580584

0 commit comments

Comments
 (0)