Skip to content

Commit e3ce538

Browse files
committed
add chain function
1 parent 88bb756 commit e3ce538

File tree

3 files changed

+45
-4
lines changed

3 files changed

+45
-4
lines changed

src/ChainRulesCore.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad
44
export frule, rrule
55
export wirtinger_conjugate, wirtinger_primal, refine_differential
66
export @scalar_rule, @thunk
7-
export extern, cast, store!
8-
export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk
7+
export extern, chain, cast, store!
8+
export Wirtinger, ComplexGradient, Zero, One, Casted, DNE, Thunk, InplaceableThunk
99
export NO_FIELDS
1010

1111
include("differentials.jl")

src/differential_arithmetic.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,40 @@ for T in (:Any,)
114114
@eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b
115115
@eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b)
116116
end
117+
118+
function chain(outer, inner; swap_order=false)
119+
if swap_order
120+
return Wirtinger(
121+
wirtinger_primal(inner) * wirtinger_primal(outer) +
122+
conj(wirtinger_conjugate(inner)) * wirtinger_conjugate(outer),
123+
wirtinger_conjugate(inner) * wirtinger_primal(outer) +
124+
conj(wirtinger_primal(inner) * wirtinger_conjugate(outer))
125+
) |> refine_differential
126+
end
127+
return Wirtinger(
128+
wirtinger_primal(outer) * wirtinger_primal(inner) +
129+
wirtinger_conjugate(outer) * conj(wirtinger_conjugate(inner)),
130+
wirtinger_primal(outer) * wirtinger_conjugate(inner) +
131+
wirtinger_conjugate(outer) * conj(wirtinger_primal(inner))
132+
) |> refine_differential
133+
end
134+
135+
function chain(outer::ComplexGradient, inner; swap_order=false)
136+
if swap_order
137+
return ComplexGradient(
138+
wirtinger_conjugate(inner) + conj(wirtinger_primal(inner)) *
139+
outer.val
140+
)
141+
end
142+
return ComplexGradient(
143+
outer.val *
144+
(wirtinger_conjugate(inner) + conj(wirtinger_primal(inner)))
145+
)
146+
end
147+
148+
function chain(outer::ComplexGradient, inner::ComplexGradient; swap_order=false)
149+
if swap_order
150+
return ComplexGradient(conj(inner.val) * outer.val)
151+
end
152+
return ComplexGradient(outer.val * conj(inner.val))
153+
end

src/differentials.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,14 +307,18 @@ function itself, when that function is not a closure.
307307
const NO_FIELDS = DNE()
308308

309309
"""
310-
refine_differential(𝒟::Type, der)
310+
refine_differential([𝒟::Type, ]der)
311311
312312
Converts, if required, a differential object `der`
313313
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.),
314314
to another differential that is more suited for the domain given by the type 𝒟.
315315
Often this will behave as the identity function on `der`.
316316
"""
317317
function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger)
318+
w = refine_differential(w)
318319
return wirtinger_primal(w) + wirtinger_conjugate(w)
319320
end
320-
refine_differential(::Any, der) = der # most of the time leave it alone.
321+
refine_differential(::Any, der) = refine_differential(der) # most of the time leave it alone.
322+
323+
refine_differential(w::Wirtinger{<:Any,Zero}) = w.primal
324+
refine_differential(der::Any) = der

0 commit comments

Comments
 (0)