Skip to content

Commit 71e0c7b

Browse files
committed
add refine_differential for ComplexGradient
1 parent 4d21e1a commit 71e0c7b

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

src/differentials.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,10 +329,16 @@ Converts, if required, a differential object `der`
329329
to another differential that is more suited for the domain given by the type 𝒟.
330330
Often this will behave as the identity function on `der`.
331331
"""
332+
function refine_differential end
333+
332334
function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger)
333335
w = refine_differential(w)
334336
return wirtinger_primal(w) + wirtinger_conjugate(w)
335337
end
338+
function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, g::ComplexGradient)
339+
g = refine_differential(g.val)
340+
return real(g)
341+
end
336342
refine_differential(::Any, der) = refine_differential(der) # most of the time leave it alone.
337343

338344
refine_differential(w::Wirtinger{<:Any,Zero}) = w.primal

test/differentials.jl

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,31 @@
8282

8383

8484
@testset "Refine Differential" begin
85-
@test refine_differential(typeof(1.0 + 1im), Wirtinger(2,2)) == Wirtinger(2,2)
86-
@test refine_differential(typeof([1.0 + 1im]), Wirtinger(2,2)) == Wirtinger(2,2)
85+
for (p, c) in (
86+
(2, -3),
87+
(2.0 + im, 5.0 - 3.0im),
88+
([1+im, 2-im], [-3+im, 4+im]),
89+
(@thunk(1+2), @thunk(4-3)),
90+
)
91+
w = Wirtinger(p, c)
92+
@testset "$w" begin
93+
@test refine_differential(typeof(1.0 + 1im), w) === w
94+
@test refine_differential(typeof([1.0 + 1im]), w) === w
8795

88-
@test refine_differential(typeof(1.2), Wirtinger(2,2)) == 4
89-
@test refine_differential(typeof([1.2]), Wirtinger(2,2)) == 4
96+
@test refine_differential(typeof(1.2), w) == p + c
97+
@test refine_differential(typeof([1.2]), w) == p + c
98+
end
99+
100+
g = ComplexGradient(c)
101+
@testset "$g" begin
102+
@test refine_differential(typeof(1.0 + 1im), g) === g
103+
@test refine_differential(typeof([1.0 + 1im]), g) === g
104+
105+
c isa Thunk && continue
106+
@test refine_differential(typeof(1.2), g) == real(c)
107+
@test refine_differential(typeof([1.2]), g) == real(c)
108+
end
109+
end
90110

91111
# For most differentials, in most domains, this does nothing
92112
for der in (DNE(), @thunk(23), [1 2], One(), Zero(), 0.0)

0 commit comments

Comments
 (0)