Skip to content

Commit 88bb756

Browse files
committed
introduce AbstractWirtinger and ComplexGradient
1 parent 1548cbc commit 88bb756

File tree

2 files changed

+53
-17
lines changed

2 files changed

+53
-17
lines changed

src/differential_arithmetic.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@ subtypes, as we know the full set that might be encountered.
77
Thus we can avoid any ambiguities.
88
99
Notice:
10-
The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
10+
The precidence goes: (:AbstractWirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
1111
Thus each of the @eval loops creating definitions of + and *
1212
defines the combination this type with all types of lower precidence.
1313
This means each eval loops is 1 item smaller than the previous.
1414
==#
1515

1616

17-
function Base.:*(a::Wirtinger, b::Wirtinger)
17+
function Base.:*(a::Union{Complex,AbstractWirtinger},
18+
b::Union{Complex,AbstractWirtinger})
1819
error("""
1920
Cannot multiply two Wirtinger objects; this error likely means a
2021
`WirtingerRule` was inappropriately defined somewhere. Multiplication
@@ -32,18 +33,33 @@ function Base.:*(a::Wirtinger, b::Wirtinger)
3233
""")
3334
end
3435

35-
function Base.:+(a::Wirtinger, b::Wirtinger)
36-
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
36+
function Base.:+(a::AbstractWirtinger, b::AbstractWirtinger)
37+
return Wirtinger(wirtinger_primal(a) + wirtinger_primal(b),
38+
wirtinger_conjugate(a) + wirtinger_conjugate(b))
3739
end
3840

39-
for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
40-
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
41-
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b
41+
Base.:+(a::ComplexGradient, b::ComplexGradient) = ComplexGradient(a.val + b.val)
42+
43+
for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk)
44+
@eval Base.:+(a::AbstractWirtinger, b::$T) = a + Wirtinger(b, Zero())
45+
@eval Base.:+(a::$T, b::AbstractWirtinger) = Wirtinger(a, Zero()) + b
4246

4347
@eval Base.:*(a::Wirtinger, b::$T) = Wirtinger(a.primal * b, a.conjugate * b)
4448
@eval Base.:*(a::$T, b::Wirtinger) = Wirtinger(a * b.primal, a * b.conjugate)
49+
50+
@eval Base.:*(a::ComplexGradient, b::$T) = ComplexGradient(a.val * b)
51+
@eval Base.:*(a::$T, b::ComplexGradient) = ComplexGradient(a * b.val)
4552
end
4653

54+
Base.:+(a::AbstractWirtinger, b) = a + Wirtinger(b, Zero())
55+
Base.:+(a, b::AbstractWirtinger) = Wirtinger(a, Zero()) + b
56+
57+
Base.:*(a::Wirtinger, b::Real) = Wirtinger(a.primal * b, a.conjugate * b)
58+
Base.:*(a::Real, b::Wirtinger) = Wirtinger(a * b.primal, a * b.conjugate)
59+
60+
Base.:*(a::ComplexGradient, b::Real) = ComplexGradient(a.val * b)
61+
Base.:*(a::Real, b::ComplexGradient) = ComplexGradient(a * b.val)
62+
4763

4864
Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value))
4965
Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value))

src/differentials.jl

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,29 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself.
4141

4242
@inline Base.conj(x::AbstractDifferential) = x
4343

44+
#####
45+
##### `AbstractWirtinger`
46+
#####
47+
48+
abstract type AbstractWirtinger <: AbstractDifferential end
49+
50+
wirtinger_primal(x) = x
51+
wirtinger_conjugate(::Any) = Zero()
52+
53+
extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type."))
54+
55+
Base.iterate(x::AbstractWirtinger) = (x, nothing)
56+
Base.iterate(::AbstractWirtinger, ::Any) = nothing
57+
58+
# `conj` is not defined for `AbstractWirtinger`
59+
Base.conj(x::AbstractWirtinger) = throw(MethodError(conj, x))
60+
4461
#####
4562
##### `Wirtinger`
4663
#####
4764

4865
"""
49-
Wirtinger(primal::Union{Number,AbstractDifferential},
50-
conjugate::Union{Number,AbstractDifferential})
66+
Wirtinger(primal, conjugate)
5167
5268
Returns a `Wirtinger` instance representing the complex differential:
5369
@@ -60,28 +76,32 @@ where `primal` corresponds to `∂f/∂z * dz` and `conjugate` corresponds to `
6076
The two fields of the returned instance can be accessed generically via the
6177
[`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref) methods.
6278
"""
63-
struct Wirtinger{P,C} <: AbstractDifferential
79+
struct Wirtinger{P,C} <: AbstractWirtinger
6480
primal::P
6581
conjugate::C
6682
end
6783

6884
wirtinger_primal(x::Wirtinger) = x.primal
69-
wirtinger_primal(x) = x
70-
7185
wirtinger_conjugate(x::Wirtinger) = x.conjugate
72-
wirtinger_conjugate(::Any) = Zero()
73-
74-
extern(x::Wirtinger) = throw(ArgumentError("`Wirtinger` cannot be converted to an external type."))
7586

7687
Base.Broadcast.broadcastable(w::Wirtinger) = Wirtinger(broadcastable(w.primal),
7788
broadcastable(w.conjugate))
7889

7990
Base.iterate(x::Wirtinger) = (x, nothing)
8091
Base.iterate(::Wirtinger, ::Any) = nothing
8192

82-
# TODO: define `conj` for` `Wirtinger`
83-
Base.conj(x::Wirtinger) = throw(MethodError(conj, x))
93+
#####
94+
##### `ComplexGradient`
95+
#####
96+
97+
struct ComplexGradient{T} <: AbstractWirtinger
98+
val::T
99+
end
100+
101+
wirtinger_primal(x::ComplexGradient) = conj(wirtinger_conjugate(x))
102+
wirtinger_conjugate(x::ComplexGradient) = x.val / 2
84103

104+
Base.Broadcast.broadcastable(x::ComplexGradient) = ComplexGradient(broadcastable(x.val))
85105

86106
#####
87107
##### `Casted`

0 commit comments

Comments
 (0)