Skip to content

Commit b0cf05b

Browse files
authored
Merge pull request #26 from JuliaDiff/oc/revert13
Revert #13
2 parents 198245d + 7bc6f65 commit b0cf05b

File tree

5 files changed

+31
-57
lines changed

5 files changed

+31
-57
lines changed

src/differentials.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,26 @@ add_thunk(a, b::Thunk) = add(a, extern(b))
303303
mul_thunk(a::Thunk, b::Thunk) = mul(extern(a), extern(b))
304304
mul_thunk(a::Thunk, b) = mul(extern(a), b)
305305
mul_thunk(a, b::Thunk) = mul(a, extern(b))
306+
307+
#####
308+
##### misc.
309+
#####
310+
311+
"""
312+
Wirtinger(primal::Real, conjugate::Real)
313+
314+
Return `add(primal, conjugate)`.
315+
316+
Actually implementing the Wirtinger calculus generally requires that the
317+
summed terms of the Wirtinger differential (`∂f/∂z * dz` and `∂f/∂z̄ * dz̄`) be
318+
stored individually. However, if both of these terms are real-valued, then
319+
downstream Wirtinger propagation mechanisms resolve to the same mechanisms as
320+
real-valued calculus, so that the terms' sum can be eagerly computed and
321+
propagated without requiring a special `Wirtinger` representation
322+
323+
This method primarily exists as an optimization.
324+
"""
325+
function Wirtinger(primal::Union{Real,DNE,Zero,One},
326+
conjugate::Union{Real,DNE,Zero,One})
327+
return add(primal, conjugate)
328+
end

src/rule_definition_tools.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,16 @@ macro scalar_rule(call, maybe_setup, partials...)
8080
end
8181
end
8282
if all(Meta.isexpr(partial, :tuple) for partial in partials)
83-
forward_rules = Any[rule_from_partials(input, partial.args...) for (input, partial) in zip(inputs, partials)]
83+
forward_rules = Any[rule_from_partials(partial.args...) for partial in partials]
8484
reverse_rules = Any[]
8585
for i in 1:length(inputs)
8686
reverse_partials = [partial.args[i] for partial in partials]
87-
push!(reverse_rules, rule_from_partials(inputs[i], reverse_partials...))
87+
push!(reverse_rules, rule_from_partials(reverse_partials...))
8888
end
8989
else
9090
@assert length(inputs) == 1 && all(!Meta.isexpr(partial, :tuple) for partial in partials)
91-
forward_rules = Any[rule_from_partials(input, partial) for (input, partial) in zip(inputs, partials)]
92-
reverse_rules = Any[rule_from_partials(inputs[1], partials...)]
91+
forward_rules = Any[rule_from_partials(partial) for partial in partials]
92+
reverse_rules = Any[rule_from_partials(partials...)]
9393
end
9494
forward_rules = length(forward_rules) == 1 ? forward_rules[1] : Expr(:tuple, forward_rules...)
9595
reverse_rules = length(reverse_rules) == 1 ? reverse_rules[1] : Expr(:tuple, reverse_rules...)
@@ -107,7 +107,7 @@ macro scalar_rule(call, maybe_setup, partials...)
107107
end
108108
end
109109

110-
function rule_from_partials(input_arg, ∂s...)
110+
function rule_from_partials(∂s...)
111111
wirtinger_indices = findall(x -> Meta.isexpr(x, :call) && x.args[1] === :Wirtinger, ∂s)
112112
∂s = map(esc, ∂s)
113113
Δs = [Symbol(string(, i)) for i in 1:length(∂s)]
@@ -140,7 +140,7 @@ function rule_from_partials(input_arg, ∂s...)
140140
conjugate_rule = :(Rule($Δs_tuple -> add($(∂_mul_Δs_conjugate...))))
141141
return quote
142142
$(∂_wirtinger_defs...)
143-
WirtingerRule(typeof($input_arg), $primal_rule, $conjugate_rule)
143+
WirtingerRule($primal_rule, $conjugate_rule)
144144
end
145145
end
146146
end

src/rules.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -184,25 +184,13 @@ DNERule(args...) = DNE()
184184
#####
185185

186186
"""
187-
WirtingerRule([𝒟::Type, ]P::AbstractRule, C::AbstractRule)
188-
Construct a `WirtingerRule` object, which is an `AbstractRule` that consists of
189-
an `AbstractRule` for both the primal derivative ``∂/∂x`` and the conjugate
190-
derivative ``∂/∂x̅``. If the domain `𝒟` is specified, return a `Rule` evaluating
191-
to `P(Δ) + C(Δ)` if `𝒟 <: Real`, otherwise return `WirtingerRule(P, C)`.
187+
TODO
192188
"""
193189
struct WirtingerRule{P<:AbstractRule,C<:AbstractRule} <: AbstractRule
194190
primal::P
195191
conjugate::C
196192
end
197193

198-
function WirtingerRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule)
199-
if 𝒟 <: Real || eltype(𝒟) <: Real
200-
return Rule((args...) -> add(primal(args...), conjugate(args...)))
201-
else
202-
return WirtingerRule(primal, conjugate)
203-
end
204-
end
205-
206194
function (rule::WirtingerRule)(args...)
207195
return Wirtinger(rule.primal(args...), rule.conjugate(args...))
208196
end

test/rules.jl

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -42,41 +42,4 @@ dummy_identity(x) = x
4242
@test rule[1] == rule
4343
@test_throws BoundsError rule[2]
4444
end
45-
46-
@testset "WirtingerRule" begin
47-
myabs2(x) = abs2(x)
48-
49-
function frule(::typeof(myabs2), x)
50-
return abs2(x), WirtingerRule(
51-
typeof(x),
52-
Rule(Δx -> Δx * x'),
53-
Rule(Δx -> Δx * x)
54-
)
55-
end
56-
57-
# real input
58-
x = rand(Float64)
59-
f, _df = @inferred frule(myabs2, x)
60-
@test f === x^2
61-
62-
df = @inferred _df(One())
63-
@test df === x + x
64-
65-
Δ = rand(Complex{Int64})
66-
df = @inferred _df(Δ)
67-
@test df === Δ * (x + x)
68-
69-
70-
# complex input
71-
z = rand(Complex{Float64})
72-
f, _df = @inferred frule(myabs2, z)
73-
@test f === abs2(z)
74-
75-
df = @inferred _df(One())
76-
@test df === Wirtinger(z', z)
77-
78-
Δ = rand(Complex{Int64})
79-
df = @inferred _df(Δ)
80-
@test df === Wirtinger* z', Δ * z)
81-
end
8245
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using LinearAlgebra: Diagonal
55
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
66
Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger,
77
Zero, add_zero, mul_zero, One, add_one, mul_one, Casted, cast, add_casted, mul_casted,
8-
DNE, Thunk, Casted, DNERule, WirtingerRule
8+
DNE, Thunk, Casted, DNERule
99
using Base.Broadcast: broadcastable
1010

1111
@testset "ChainRulesCore" begin

0 commit comments

Comments
 (0)