Skip to content

Commit 2618146

Browse files
committed
introduce a function unwrap_wirtinger
1 parent 186009a commit 2618146

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

src/differential_arithmetic.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,10 @@ for T in (:Any,)
115115
@eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b)
116116
end
117117

118-
@inline function chain(outer, inner, swap_order=false)
118+
@inline chain(outer, inner, swap_order=false) =
119+
_chain(unwrap_wirtiner(outer), unwrap_wirtinger(inner), swap_order)
120+
121+
@inline function _chain(outer, inner, swap_order)
119122
if swap_order
120123
return Wirtinger(
121124
wirtinger_primal(inner) * wirtinger_primal(outer) +
@@ -132,7 +135,7 @@ end
132135
) |> refine_differential
133136
end
134137

135-
@inline function chain(outer::ComplexGradient, inner, swap_order=false)
138+
@inline function _chain(outer::ComplexGradient, inner, swap_order)
136139
if swap_order
137140
return ComplexGradient(
138141
(wirtinger_conjugate(inner) + conj(wirtinger_primal(inner))) *
@@ -145,7 +148,7 @@ end
145148
)
146149
end
147150

148-
@inline function chain(outer::ComplexGradient, inner::ComplexGradient, swap_order=false)
151+
@inline function _chain(outer::ComplexGradient, inner::ComplexGradient, swap_order)
149152
if swap_order
150153
return ComplexGradient(conj(inner.val) * outer.val)
151154
end

src/differentials.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,15 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself.
4747

4848
abstract type AbstractWirtinger <: AbstractDifferential end
4949

50+
unwrap_wirtinger(x) = x
51+
unwrap_wirtinger(x::Union{Casted,AbstractThunk}) = unwrap_wirtinger(extern(x))
52+
5053
wirtinger_primal(x) = x
54+
wirtinger_primal(x::Union{Casted,AbstractThunk}) =
55+
throw(ArgumentError("`wirtinger_primal` is not defined for $(typeof(x)). Call `unwrap_wirtinger` first")
5156
wirtinger_conjugate(::Any) = Zero()
57+
wirtinger_primal(x::Union{Casted,AbstractThunk}) =
58+
throw(ArgumentError("`wirtinger_conjugate` is not defined for $(typeof(x)). Call `unwrap_wirtinger` first")
5259

5360
extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type."))
5461

0 commit comments

Comments
 (0)