Skip to content

Commit 9a56965

Browse files
authored
Merge pull request #79 from JuliaDiff/ox/nowirtinger
Remove Wirtinger
2 parents 71d44a9 + 9e5e1e3 commit 9a56965

File tree

9 files changed

+11
-300
lines changed

9 files changed

+11
-300
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,14 @@ module ChainRulesCore
22
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
33

44
export frule, rrule
5-
export refine_differential, wirtinger_conjugate, wirtinger_primal
65
export @scalar_rule, @thunk
76
export extern, store!, unthunk
8-
export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Wirtinger, Zero
7+
export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero
98
export NO_FIELDS
109

1110
include("compat.jl")
1211

1312
include("differentials/abstract_differential.jl")
14-
include("differentials/wirtinger.jl")
1513
include("differentials/zero.jl")
1614
include("differentials/does_not_exist.jl")
1715
include("differentials/one.jl")

src/differential_arithmetic.jl

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,44 +8,12 @@ Thus we can avoid any ambiguities.
88
99
Notice:
1010
The precedence goes:
11-
`Wirtinger, Zero, DoesNotExist, One, AbstractThunk, Composite, Any`
11+
`Zero, DoesNotExist, One, AbstractThunk, Composite, Any`
1212
Thus each of the @eval loops creating definitions of + and *
1313
defines the combination of this type with all types of lower precedence.
1414
This means each eval loop is 1 item smaller than the previous.
1515
==#
1616

17-
18-
function Base.:*(a::Wirtinger, b::Wirtinger)
19-
error("""
20-
Cannot multiply two Wirtinger objects; this error likely means a
21-
`WirtingerRule` was inappropriately defined somewhere. Multiplication
22-
of two Wirtinger objects is not defined because chain rule application
23-
often expands into a non-commutative operation in the Wirtinger
24-
calculus. To put it another way: simply given two Wirtinger objects
25-
and no other information, we can't know "locally" which components to
26-
conjugate in order to implement the chain rule. We could pick a
27-
convention; for example, we could define `a::Wirtinger * b::Wirtinger`
28-
such that we assume the chain rule application is of the form `f_a ∘ f_b`
29-
instead of `f_b ∘ f_a`. However, picking such a convention is likely to
30-
lead to silently incorrect derivatives due to commutativity assumptions
31-
in downstream generic code that deals with the reals. Thus, ChainRulesCore
32-
makes this operation an error instead.
33-
""")
34-
end
35-
36-
function Base.:+(a::Wirtinger, b::Wirtinger)
37-
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
38-
end
39-
40-
for T in (:Zero, :DoesNotExist, :One, :AbstractThunk, :Any)
41-
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
42-
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b
43-
44-
@eval Base.:*(a::Wirtinger, b::$T) = Wirtinger(a.primal * b, a.conjugate * b)
45-
@eval Base.:*(a::$T, b::Wirtinger) = Wirtinger(a * b.primal, a * b.conjugate)
46-
end
47-
48-
4917
Base.:+(::Zero, b::Zero) = Zero()
5018
Base.:*(::Zero, ::Zero) = Zero()
5119
for T in (:DoesNotExist, :One, :AbstractThunk, :Any)

src/differentials/abstract_differential.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,3 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself.
4040
@inline extern(x) = x
4141

4242
@inline Base.conj(x::AbstractDifferential) = x
43-
44-
"""
45-
refine_differential(𝒟::Type, der)
46-
47-
Converts, if required, a differential object `der`
48-
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.),
49-
to another differential that is more suited for the domain given by the type 𝒟.
50-
Often this will behave as the identity function on `der`.
51-
"""
52-
refine_differential(::Any, der) = der # most of the time leave it alone.

src/differentials/wirtinger.jl

Lines changed: 0 additions & 44 deletions
This file was deleted.

src/rule_definition_tools.jl

Lines changed: 9 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,8 @@ macro scalar_rule(call, maybe_setup, partials...)
7474
)
7575
f = call.args[1]
7676

77-
# An expression that when evaluated will return the type of the input domain.
78-
# Multiple repetitions of this expression should optimize out. But if it does not then
79-
# may need to move its definition into the body of the `rrule`/`frule`
80-
𝒟 = :(typeof(first(promote($(call.args[2:end]...)))))
81-
82-
frule_expr = scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
83-
rrule_expr = scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
84-
77+
frule_expr = scalar_frule_expr(f, call, setup_stmts, inputs, partials)
78+
rrule_expr = scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
8579

8680
############################################################################
8781
# Final return: building the expression to insert in the place of this macro
@@ -147,7 +141,7 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
147141
return call, setup_stmts, inputs, partials
148142
end
149143

150-
function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
144+
function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
151145
n_outputs = length(partials)
152146
n_inputs = length(inputs)
153147

@@ -156,7 +150,7 @@ function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
156150
Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs]
157151
pushforward_returns = map(1:n_outputs) do output_i
158152
∂s = partials[output_i].args
159-
propagation_expr(𝒟, Δs, ∂s)
153+
propagation_expr(Δs, ∂s)
160154
end
161155
if n_outputs > 1
162156
# For forward-mode we only return a tuple if the output is actually a tuple.
@@ -182,7 +176,7 @@ function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
182176
end
183177
end
184178

185-
function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
179+
function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
186180
n_outputs = length(partials)
187181
n_inputs = length(inputs)
188182

@@ -193,7 +187,7 @@ function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
193187
# 1 partial derivative per input
194188
pullback_returns = map(1:n_inputs) do input_i
195189
∂s = [partial.args[input_i] for partial in partials]
196-
propagation_expr(𝒟, Δs, ∂s)
190+
propagation_expr(Δs, ∂s)
197191
end
198192

199193
pullback = quote
@@ -212,30 +206,14 @@ function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
212206
end
213207

214208
"""
215-
propagation_expr(𝒟, Δs, ∂s)
209+
propagation_expr(Δs, ∂s)
216210
217211
Returns the expression for the propagation of
218212
the input gradient `Δs` through the partials `∂s`.
219-
220-
𝒟 is an expression that when evaluated returns the type-of the input domain.
221-
For example if the derivative is being taken at the point `1` it returns `Int`.
222-
if it is taken at `1+1im` it returns `Complex{Int}`.
223-
At present it is ignored for non-Wirtinger derivatives.
224213
"""
225-
function propagation_expr(𝒟, Δs, ∂s)
226-
wirtinger_indices = findall(∂s) do ex
227-
Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger
228-
end
229-
∂s = map(esc, ∂s)
230-
if isempty(wirtinger_indices)
231-
return standard_propagation_expr(Δs, ∂s)
232-
else
233-
return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
234-
end
235-
end
236-
237-
function standard_propagation_expr(Δs, ∂s)
214+
function propagation_expr(Δs, ∂s)
238215
# This is basically Δs ⋅ ∂s
216+
∂s = map(esc, ∂s)
239217

240218
# Notice: the thunking of `∂s[i]` (potentially) saves us some computation
241219
# if `Δs[i]` is an `AbstractDifferential`; otherwise it is computed as soon
@@ -244,36 +222,6 @@ function standard_propagation_expr(Δs, ∂s)
244222
return :(+($(∂_mul_Δs...)))
245223
end
246224

247-
function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
248-
∂_mul_Δs_primal = Any[]
249-
∂_mul_Δs_conjugate = Any[]
250-
∂_wirtinger_defs = Any[]
251-
for i in 1:length(∂s)
252-
if i in wirtinger_indices
253-
Δi = Δs[i]
254-
∂i = Symbol(string(:∂, i))
255-
push!(∂_wirtinger_defs, :($∂i = $(∂s[i])))
256-
∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi))
257-
∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi))
258-
∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi))
259-
∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi))
260-
push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄))
261-
push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄))
262-
else
263-
∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i]))
264-
push!(∂_mul_Δs_primal, ∂_mul_Δ)
265-
push!(∂_mul_Δs_conjugate, ∂_mul_Δ)
266-
end
267-
end
268-
primal_sum = :(+($(∂_mul_Δs_primal...)))
269-
conjugate_sum = :(+($(∂_mul_Δs_conjugate...)))
270-
return quote # This will be a block, so will have value equal to last statement
271-
$(∂_wirtinger_defs...)
272-
w = Wirtinger($primal_sum, $conjugate_sum)
273-
refine_differential($𝒟, w)
274-
end
275-
end
276-
277225
"""
278226
propagator_name(f, propname)
279227

test/differentials/wirtinger.jl

Lines changed: 0 additions & 19 deletions
This file was deleted.

test/differentials_common.jl

Lines changed: 0 additions & 16 deletions
This file was deleted.

test/rules.jl

Lines changed: 0 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -38,114 +38,3 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
3838
@test rrx == 2
3939
@test rr1 == 1
4040
end
41-
42-
43-
@testset "Basic Wirtinger scalar_rule" begin
44-
myabs2(x) = abs2(x)
45-
@scalar_rule(myabs2(x), Wirtinger(x', x))
46-
47-
@testset "real input" begin
48-
# even though our rule was defined in terms of Wirtinger,
49-
# pushforward result will be real as real (even if seed is Complex)
50-
51-
x = rand(Float64)
52-
f, myabs2_pushforward = frule(myabs2, x)
53-
@test f === x^2
54-
55-
Δ = One()
56-
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
57-
@test df === x + x
58-
59-
Δ = rand(Complex{Int64})
60-
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
61-
@test df === Δ * (x + x)
62-
end
63-
64-
@testset "complex input" begin
65-
z = rand(Complex{Float64})
66-
f, myabs2_pushforward = frule(myabs2, z)
67-
@test f === abs2(z)
68-
69-
df = @inferred myabs2_pushforward(NamedTuple(), One())
70-
@test df === Wirtinger(z', z)
71-
72-
Δ = rand(Complex{Int64})
73-
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
74-
@test df === Wirtinger(Δ * z', Δ * z)
75-
end
76-
end
77-
78-
79-
@testset "Advanced Wirtinger @scalar_rule: abs_to_pow" begin
80-
# This is based on SimeonSchaub's excellent example:
81-
# https://gist.github.com/simeonschaub/a6dfcd71336d863b3777093b3b8d9c97
82-
83-
# This is much more complex than the previous case
84-
# as it has many different types
85-
# depending on input, and the output types do not always agree
86-
87-
abs_to_pow(x, p) = abs(x)^p
88-
@scalar_rule(
89-
abs_to_pow(x::Real, p),
90-
(
91-
p == 0 ? Zero() : p * abs_to_pow(x, p-1) * sign(x),
92-
Ω * log(abs(x))
93-
)
94-
)
95-
96-
@scalar_rule(
97-
abs_to_pow(x::Complex, p),
98-
@setup(u = abs(x)),
99-
(
100-
p == 0 ? Zero() : p * u^(p-1) * Wirtinger(x' / 2u, x / 2u),
101-
Ω * log(abs(x))
102-
)
103-
)
104-
105-
106-
f = abs_to_pow
107-
@testset "f($x, $p)" for (x, p) in Iterators.product(
108-
(2, 3.4, -2.1, -10+0im, 2.3-2im),
109-
(0, 1, 2, 4.3, -2.1, 1+.2im)
110-
)
111-
expected_type_df_dx =
112-
if iszero(p)
113-
Zero
114-
elseif typeof(x) <: Complex
115-
Wirtinger
116-
elseif typeof(p) <: Complex
117-
Complex
118-
else
119-
Real
120-
end
121-
122-
expected_type_df_dp =
123-
if typeof(p) <: Real
124-
Real
125-
else
126-
Complex
127-
end
128-
129-
130-
res = frule(f, x, p)
131-
@test res !== nothing # Check the rule was defined
132-
fx, f_pushforward = res
133-
df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp)
134-
135-
df_dx::Thunk = df(One(), Zero())
136-
df_dp::Thunk = df(Zero(), One())
137-
@test fx == f(x, p) # Check we still get the normal value, right
138-
@test df_dx() isa expected_type_df_dx
139-
@test df_dp() isa expected_type_df_dp
140-
141-
142-
res = rrule(f, x, p)
143-
@test res !== nothing # Check the rule was defined
144-
fx, f_pullback = res
145-
dself, df_dx, df_dp = f_pullback(One())
146-
@test fx == f(x, p) # Check we still get the normal value, right
147-
@test dself == NO_FIELDS
148-
@test df_dx() isa expected_type_df_dx
149-
@test df_dp() isa expected_type_df_dp
150-
end
151-
end

0 commit comments

Comments
 (0)