Skip to content

Commit f34bf9c

Browse files
simeonschauboxinabox
authored andcommitted
fix wirtinger mechanism (#29)
* fix wirtinger mechanism * another attempt... * implement proposed changes * fix tests accordingly * fix another omission
1 parent 495a5f6 commit f34bf9c

File tree

5 files changed

+66
-31
lines changed

5 files changed

+66
-31
lines changed

src/differentials.jl

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -303,26 +303,3 @@ add_thunk(a, b::Thunk) = add(a, extern(b))
303303
mul_thunk(a::Thunk, b::Thunk) = mul(extern(a), extern(b))
304304
mul_thunk(a::Thunk, b) = mul(extern(a), b)
305305
mul_thunk(a, b::Thunk) = mul(a, extern(b))
306-
307-
#####
308-
##### misc.
309-
#####
310-
311-
"""
312-
Wirtinger(primal::Real, conjugate::Real)
313-
314-
Return `add(primal, conjugate)`.
315-
316-
Actually implementing the Wirtinger calculus generally requires that the
317-
summed terms of the Wirtinger differential (`∂f/∂z * dz` and `∂f/∂z̄ * dz̄`) be
318-
stored individually. However, if both of these terms are real-valued, then
319-
downstream Wirtinger propagation mechanisms resolve to the same mechanisms as
320-
real-valued calculus, so that the terms' sum can be eagerly computed and
321-
propagated without requiring a special `Wirtinger` representation
322-
323-
This method primarily exists as an optimization.
324-
"""
325-
function Wirtinger(primal::Union{Real,DNE,Zero,One},
326-
conjugate::Union{Real,DNE,Zero,One})
327-
return add(primal, conjugate)
328-
end

src/rule_definition_tools.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,17 @@ macro scalar_rule(call, maybe_setup, partials...)
8080
end
8181
end
8282
if all(Meta.isexpr(partial, :tuple) for partial in partials)
83-
forward_rules = Any[rule_from_partials(partial.args...) for partial in partials]
83+
input_rep = :(first(promote($(inputs...)))) # stand-in with the right type for an input
84+
forward_rules = Any[rule_from_partials(input_rep, partial.args...) for partial in partials]
8485
reverse_rules = Any[]
8586
for i in 1:length(inputs)
8687
reverse_partials = [partial.args[i] for partial in partials]
87-
push!(reverse_rules, rule_from_partials(reverse_partials...))
88+
push!(reverse_rules, rule_from_partials(inputs[i], reverse_partials...))
8889
end
8990
else
9091
@assert length(inputs) == 1 && all(!Meta.isexpr(partial, :tuple) for partial in partials)
91-
forward_rules = Any[rule_from_partials(partial) for partial in partials]
92-
reverse_rules = Any[rule_from_partials(partials...)]
92+
forward_rules = Any[rule_from_partials(inputs[1], partial) for partial in partials]
93+
reverse_rules = Any[rule_from_partials(inputs[1], partials...)]
9394
end
9495
forward_rules = length(forward_rules) == 1 ? forward_rules[1] : Expr(:tuple, forward_rules...)
9596
reverse_rules = length(reverse_rules) == 1 ? reverse_rules[1] : Expr(:tuple, reverse_rules...)
@@ -107,7 +108,7 @@ macro scalar_rule(call, maybe_setup, partials...)
107108
end
108109
end
109110

110-
function rule_from_partials(∂s...)
111+
function rule_from_partials(input_arg, ∂s...)
111112
wirtinger_indices = findall(x -> Meta.isexpr(x, :call) && x.args[1] === :Wirtinger, ∂s)
112113
∂s = map(esc, ∂s)
113114
Δs = [Symbol(string(, i)) for i in 1:length(∂s)]
@@ -140,7 +141,7 @@ function rule_from_partials(∂s...)
140141
conjugate_rule = :(Rule($Δs_tuple -> add($(∂_mul_Δs_conjugate...))))
141142
return quote
142143
$(∂_wirtinger_defs...)
143-
WirtingerRule($primal_rule, $conjugate_rule)
144+
AbstractRule(typeof($input_arg), $primal_rule, $conjugate_rule)
144145
end
145146
end
146147
end

src/rules.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,13 @@ DNERule(args...) = DNE()
184184
#####
185185

186186
"""
187-
TODO
187+
WirtingerRule(primal::AbstractRule, conjugate::AbstractRule)
188+
189+
Construct a `WirtingerRule` object, which is an `AbstractRule` that consists of
190+
an `AbstractRule` for both the primal derivative ``∂/∂x`` and the conjugate
191+
derivative ``∂/∂x̅``. If the domain `𝒟` of the function might be real, consider
192+
calling `AbstractRule(𝒟, primal, conjugate)` instead, to make use of a more
193+
efficient representation wherever possible.
188194
"""
189195
struct WirtingerRule{P<:AbstractRule,C<:AbstractRule} <: AbstractRule
190196
primal::P
@@ -195,6 +201,20 @@ function (rule::WirtingerRule)(args...)
195201
return Wirtinger(rule.primal(args...), rule.conjugate(args...))
196202
end
197203

204+
"""
205+
AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule)
206+
207+
Return a `Rule` evaluating to `primal(Δ) + conjugate(Δ)` if `𝒟 <: Real`,
208+
otherwise return `WirtingerRule(P, C)`.
209+
"""
210+
function AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule)
211+
if 𝒟 <: Real || eltype(𝒟) <: Real
212+
return Rule((args...) -> add(primal(args...), conjugate(args...)))
213+
else
214+
return WirtingerRule(primal, conjugate)
215+
end
216+
end
217+
198218
#####
199219
##### `frule`/`rrule`
200220
#####

test/rules.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,41 @@ dummy_identity(x) = x
4242
@test rule[1] == rule
4343
@test_throws BoundsError rule[2]
4444
end
45+
46+
@testset "WirtingerRule" begin
47+
myabs2(x) = abs2(x)
48+
49+
function frule(::typeof(myabs2), x)
50+
return abs2(x), AbstractRule(
51+
typeof(x),
52+
Rule(Δx -> Δx * x'),
53+
Rule(Δx -> Δx * x)
54+
)
55+
end
56+
57+
# real input
58+
x = rand(Float64)
59+
f, _df = @inferred frule(myabs2, x)
60+
@test f === x^2
61+
62+
df = @inferred _df(One())
63+
@test df === x + x
64+
65+
Δ = rand(Complex{Int64})
66+
df = @inferred _df(Δ)
67+
@test df === Δ * (x + x)
68+
69+
70+
# complex input
71+
z = rand(Complex{Float64})
72+
f, _df = @inferred frule(myabs2, z)
73+
@test f === abs2(z)
74+
75+
df = @inferred _df(One())
76+
@test df === Wirtinger(z', z)
77+
78+
Δ = rand(Complex{Int64})
79+
df = @inferred _df(Δ)
80+
@test df === Wirtinger* z', Δ * z)
81+
end
4582
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using LinearAlgebra: Diagonal
55
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
66
Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger,
77
Zero, add_zero, mul_zero, One, add_one, mul_one, Casted, cast, add_casted, mul_casted,
8-
DNE, Thunk, Casted, DNERule
8+
DNE, Thunk, Casted, DNERule, WirtingerRule
99
using Base.Broadcast: broadcastable
1010

1111
@testset "ChainRulesCore" begin

0 commit comments

Comments
 (0)