Skip to content

Remove Wirtinger #79

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@ module ChainRulesCore
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!

export frule, rrule
export refine_differential, wirtinger_conjugate, wirtinger_primal
export @scalar_rule, @thunk
export extern, store!, unthunk
export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Wirtinger, Zero
export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero
export NO_FIELDS

include("compat.jl")

include("differentials/abstract_differential.jl")
include("differentials/wirtinger.jl")
include("differentials/zero.jl")
include("differentials/does_not_exist.jl")
include("differentials/one.jl")
Expand Down
34 changes: 1 addition & 33 deletions src/differential_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,44 +8,12 @@ Thus we can avoid any ambiguities.

Notice:
The precedence goes:
`Wirtinger, Zero, DoesNotExist, One, AbstractThunk, Composite, Any`
`Zero, DoesNotExist, One, AbstractThunk, Composite, Any`
Thus each of the @eval loops creating definitions of + and *
defines the combination of this type with all types of lower precedence.
This means each eval loop is 1 item smaller than the previous.
==#


# The product of two `Wirtinger` differentials is deliberately left undefined:
# this method exists only to raise an explanatory error.
function Base.:*(a::Wirtinger, b::Wirtinger)
    throw(ErrorException("""
        Cannot multiply two Wirtinger objects; this error likely means a
        `WirtingerRule` was inappropriately defined somewhere. Multiplication
        of two Wirtinger objects is not defined because chain rule application
        often expands into a non-commutative operation in the Wirtinger
        calculus. To put it another way: simply given two Wirtinger objects
        and no other information, we can't know "locally" which components to
        conjugate in order to implement the chain rule. We could pick a
        convention; for example, we could define `a::Wirtinger * b::Wirtinger`
        such that we assume the chain rule application is of the form `f_a ∘ f_b`
        instead of `f_b ∘ f_a`. However, picking such a convention is likely to
        lead to silently incorrect derivatives due to commutativity assumptions
        in downstream generic code that deals with the reals. Thus, ChainRulesCore
        makes this operation an error instead.
        """))
end

# Two `Wirtinger` differentials add componentwise: primal with primal,
# conjugate with conjugate.
Base.:+(a::Wirtinger, b::Wirtinger) =
    Wirtinger(a.primal + b.primal, a.conjugate + b.conjugate)

# Define `+` and `*` between a `Wirtinger` and every type listed below, in both
# argument orders.
for T in (:Zero, :DoesNotExist, :One, :AbstractThunk, :Any)
    @eval begin
        # Addition: wrap the other operand as a `Wirtinger` with a `Zero()` conjugate
        # part, then use the Wirtinger + Wirtinger method.
        Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
        Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b

        # Multiplication: apply the other operand to both components, keeping the
        # operand order of the original product.
        Base.:*(a::Wirtinger, b::$T) = Wirtinger(a.primal * b, a.conjugate * b)
        Base.:*(a::$T, b::Wirtinger) = Wirtinger(a * b.primal, a * b.conjugate)
    end
end


Base.:+(::Zero, b::Zero) = Zero()
Base.:*(::Zero, ::Zero) = Zero()
for T in (:DoesNotExist, :One, :AbstractThunk, :Any)
Expand Down
10 changes: 0 additions & 10 deletions src/differentials/abstract_differential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,3 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself.
@inline extern(x) = x

# `conj` acts as the identity on any `AbstractDifferential`.
@inline function Base.conj(x::AbstractDifferential)
    return x
end

"""
refine_differential(𝒟::Type, der)

Converts, if required, a differential object `der`
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.),
to another differential that is more suited for the domain given by the type 𝒟.
Often this will behave as the identity function on `der`.
"""
function refine_differential(::Any, der)
    # Generic fallback: whatever the domain is, hand the differential back untouched.
    return der
end
44 changes: 0 additions & 44 deletions src/differentials/wirtinger.jl

This file was deleted.

70 changes: 9 additions & 61 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,8 @@ macro scalar_rule(call, maybe_setup, partials...)
)
f = call.args[1]

# An expression that when evaluated will return the type of the input domain.
# Multiple repetitions of this expression should optimize out. But if it does not then
# may need to move its definition into the body of the `rrule`/`frule`
𝒟 = :(typeof(first(promote($(call.args[2:end]...)))))

frule_expr = scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
rrule_expr = scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)

frule_expr = scalar_frule_expr(f, call, setup_stmts, inputs, partials)
rrule_expr = scalar_rrule_expr(f, call, setup_stmts, inputs, partials)

############################################################################
# Final return: building the expression to insert in the place of this macro
Expand Down Expand Up @@ -147,7 +141,7 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
return call, setup_stmts, inputs, partials
end

function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
n_outputs = length(partials)
n_inputs = length(inputs)

Expand All @@ -156,7 +150,7 @@ function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs]
pushforward_returns = map(1:n_outputs) do output_i
∂s = partials[output_i].args
propagation_expr(𝒟, Δs, ∂s)
propagation_expr(Δs, ∂s)
end
if n_outputs > 1
# For forward-mode we only return a tuple if the output is actually a tuple.
Expand All @@ -182,7 +176,7 @@ function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
end
end

function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
n_outputs = length(partials)
n_inputs = length(inputs)

Expand All @@ -193,7 +187,7 @@ function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
# 1 partial derivative per input
pullback_returns = map(1:n_inputs) do input_i
∂s = [partial.args[input_i] for partial in partials]
propagation_expr(𝒟, Δs, ∂s)
propagation_expr(Δs, ∂s)
end

pullback = quote
Expand All @@ -212,30 +206,14 @@ function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
end

"""
propagation_expr(𝒟, Δs, ∂s)
propagation_expr(Δs, ∂s)

Returns the expression for the propagation of
the input gradient `Δs` through the partials `∂s`.

𝒟 is an expression that when evaluated returns the type-of the input domain.
For example if the derivative is being taken at the point `1` it returns `Int`.
if it is taken at `1+1im` it returns `Complex{Int}`.
At present it is ignored for non-Wirtinger derivatives.
"""
function propagation_expr(𝒟, Δs, ∂s)
    # A partial counts as "Wirtinger" only when it is written literally as a
    # `Wirtinger(...)` call expression.
    is_wirtinger_call(ex) = Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger
    wirtinger_indices = findall(is_wirtinger_call, ∂s)

    # The indices are computed on the raw expressions; escaping happens afterwards.
    escaped_∂s = map(esc, ∂s)

    isempty(wirtinger_indices) && return standard_propagation_expr(Δs, escaped_∂s)
    return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, escaped_∂s)
end

function standard_propagation_expr(Δs, ∂s)
function propagation_expr(Δs, ∂s)
# This is basically Δs ⋅ ∂s
∂s = map(esc, ∂s)

# Notice: the thunking of `∂s[i]` (potentially) saves us some computation
# if `Δs[i]` is an `AbstractDifferential` otherwise it is computed as soon
Expand All @@ -244,36 +222,6 @@ function standard_propagation_expr(Δs, ∂s)
return :(+($(∂_mul_Δs...)))
end

# Build the propagation expression for the case where at least one partial is a
# `Wirtinger(...)` call.
#
# - `𝒟`: expression that is spliced in as the first argument of `refine_differential`.
# - `wirtinger_indices`: positions in `∂s` that hold Wirtinger partials.
# - `Δs`: input differentials; `∂s`: the partials (same length as `Δs`).
#
# Returns a block expression whose value is
# `refine_differential(𝒟, Wirtinger(primal_sum, conjugate_sum))`.
function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
    primal_terms = Any[]
    conjugate_terms = Any[]
    partial_bindings = Any[]

    for (i, ∂) in enumerate(∂s)
        Δ = Δs[i]
        if i in wirtinger_indices
            # Bind the partial to a local name (`∂1`, `∂2`, …) so the generated code
            # evaluates it once even though it is referenced four times below.
            ∂name = Symbol(:∂, i)
            push!(partial_bindings, Expr(:(=), ∂name, ∂))

            primal_of_∂ = :(wirtinger_primal($∂name))
            conjugate_of_∂ = :(wirtinger_conjugate($∂name))
            primal_of_Δ = :(wirtinger_primal($Δ))
            conjugate_of_Δ = :(wirtinger_conjugate($Δ))

            push!(
                primal_terms,
                :($primal_of_∂ * $primal_of_Δ + conj($conjugate_of_∂) * $conjugate_of_Δ),
            )
            push!(
                conjugate_terms,
                :($conjugate_of_∂ * $primal_of_Δ + conj($primal_of_∂) * $conjugate_of_Δ),
            )
        else
            # Non-Wirtinger partials contribute the same thunked product to both sums.
            term = :(@thunk($∂) * $Δ)
            push!(primal_terms, term)
            push!(conjugate_terms, term)
        end
    end

    primal_sum = Expr(:call, :+, primal_terms...)
    conjugate_sum = Expr(:call, :+, conjugate_terms...)

    # The block's value is that of its last statement.
    return quote
        $(partial_bindings...)
        w = Wirtinger($primal_sum, $conjugate_sum)
        refine_differential($𝒟, w)
    end
end

"""
propagator_name(f, propname)

Expand Down
19 changes: 0 additions & 19 deletions test/differentials/wirtinger.jl

This file was deleted.

16 changes: 0 additions & 16 deletions test/differentials_common.jl

This file was deleted.

111 changes: 0 additions & 111 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,114 +38,3 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
@test rrx == 2
@test rr1 == 1
end


# Exercises `@scalar_rule` with a partial written as a `Wirtinger(...)` call and
# checks the forward-mode (`frule`) results for real and complex inputs.
@testset "Basic Wirtinger scalar_rule" begin
myabs2(x) = abs2(x)
@scalar_rule(myabs2(x), Wirtinger(x', x))

@testset "real input" begin
# Even though our rule was defined in terms of Wirtinger, for real input the
# pushforward result is a plain number rather than a `Wirtinger`
# (even if the seed is Complex).

x = rand(Float64)
f, myabs2_pushforward = frule(myabs2, x)
@test f === x^2

Δ = One()
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
@test df === x + x

Δ = rand(Complex{Int64})
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
@test df === Δ * (x + x)
end

@testset "complex input" begin
# For complex input the pushforward result is expected to be a `Wirtinger`.
z = rand(Complex{Float64})
f, myabs2_pushforward = frule(myabs2, z)
@test f === abs2(z)

df = @inferred myabs2_pushforward(NamedTuple(), One())
@test df === Wirtinger(z', z)

Δ = rand(Complex{Int64})
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
@test df === Wirtinger(Δ * z', Δ * z)
end
end


# Defines separate `@scalar_rule`s for real and complex `x` on the same function and
# checks the types of the partials produced by both `frule` and `rrule`.
@testset "Advanced Wirtinger @scalar_rule: abs_to_pow" begin
# This is based on Simeon Schaub's excellent example:
# https://gist.github.com/simeonschaub/a6dfcd71336d863b3777093b3b8d9c97

# This is much more complex than the previous case
# as it has many different types
# depending on input, and the output types do not always agree

abs_to_pow(x, p) = abs(x)^p
# Rule for real `x`: the partial with respect to `x` is `Zero()` when `p == 0`.
@scalar_rule(
abs_to_pow(x::Real, p),
(
p == 0 ? Zero() : p * abs_to_pow(x, p-1) * sign(x),
Ω * log(abs(x))
)
)

# Rule for complex `x`: the partial with respect to `x` is built from a `Wirtinger`
# unless `p == 0`.
@scalar_rule(
abs_to_pow(x::Complex, p),
@setup(u = abs(x)),
(
p == 0 ? Zero() : p * u^(p-1) * Wirtinger(x' / 2u, x / 2u),
Ω * log(abs(x))
)
)


f = abs_to_pow
@testset "f($x, $p)" for (x, p) in Iterators.product(
(2, 3.4, -2.1, -10+0im, 2.3-2im),
(0, 1, 2, 4.3, -2.1, 1+.2im)
)
# Type expected for the partial with respect to `x` once it is evaluated.
expected_type_df_dx =
if iszero(p)
Zero
elseif typeof(x) <: Complex
Wirtinger
elseif typeof(p) <: Complex
Complex
else
Real
end

# Type expected for the partial with respect to `p` once it is evaluated.
expected_type_df_dp =
if typeof(p) <: Real
Real
else
Complex
end


# Forward mode.
res = frule(f, x, p)
@test res !== nothing # Check the rule was defined
fx, f_pushforward = res
df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp)

# The `::Thunk` declarations require the pushforward results to be
# (convertible to) `Thunk`s.
df_dx::Thunk = df(One(), Zero())
df_dp::Thunk = df(Zero(), One())
@test fx == f(x, p) # Check we still get the normal value, right
@test df_dx() isa expected_type_df_dx
@test df_dp() isa expected_type_df_dp


# Reverse mode.
res = rrule(f, x, p)
@test res !== nothing # Check the rule was defined
fx, f_pullback = res
dself, df_dx, df_dp = f_pullback(One())
@test fx == f(x, p) # Check we still get the normal value, right
@test dself == NO_FIELDS
@test df_dx() isa expected_type_df_dx
@test df_dp() isa expected_type_df_dp
end
end
Loading