Skip to content

WIP: Wirtinger support #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
1548cbc
remove type constraints for Wirtinger
simeonschaub Sep 22, 2019
88bb756
introduce `AbstractWirtinger` and `ComplexGradient`
simeonschaub Sep 23, 2019
e3ce538
add `chain` function
simeonschaub Sep 23, 2019
186009a
make `swap_order` in `chain` a positional arg
simeonschaub Sep 23, 2019
2618146
introduce a function `unwrap_wirtinger`
simeonschaub Sep 23, 2019
aa7a84a
stop using types before they are defined
simeonschaub Sep 23, 2019
20e7134
fix tests
simeonschaub Sep 25, 2019
c42a953
rename `unwrap_wirtinger` -> `unthunk`
simeonschaub Sep 27, 2019
2ac8801
fix `chain` function
simeonschaub Sep 27, 2019
61a60ec
use the new `chain` function in `@scalar_rule`
simeonschaub Sep 27, 2019
41f1071
fix tests accordingly
simeonschaub Sep 27, 2019
93a7b14
Update src/differentials.jl
simeonschaub Oct 3, 2019
c904407
add chain(::Real, ::ComplexGradient)
simeonschaub Oct 5, 2019
a83da2f
Update src/differentials.jl
simeonschaub Oct 18, 2019
4d21e1a
special case `AbstractWirtinger` in `at_thunk`
simeonschaub Oct 19, 2019
71e0c7b
add `refine_differential` for `ComplexGradient`
simeonschaub Oct 19, 2019
491f781
overload `Base.real` for some differentials
simeonschaub Oct 19, 2019
06ccb14
add some docstrings
simeonschaub Oct 19, 2019
6ce400c
move `at_thunk`-magic into separate macro
simeonschaub Oct 19, 2019
b960a09
make at_scalar_rule detect wrong Wirtinger rules
simeonschaub Oct 19, 2019
3bc7e22
use `_thunk` as a function, not macro
simeonschaub Oct 19, 2019
e3bf56d
implement some of @oxinabox's suggestions
simeonschaub Oct 19, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad
export frule, rrule
export wirtinger_conjugate, wirtinger_primal, refine_differential
export @scalar_rule, @thunk
export extern, cast, store!
export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk
export extern, chain, cast, store!
export Wirtinger, ComplexGradient, Zero, One, Casted, DNE, Thunk, InplaceableThunk
export NO_FIELDS

include("differentials.jl")
Expand Down
70 changes: 63 additions & 7 deletions src/differential_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ subtypes, as we know the full set that might be encountered.
Thus we can avoid any ambiguities.

Notice:
The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
The precidence goes: (:AbstractWirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
Thus each of the @eval loops creating definitions of + and *
defines the combination this type with all types of lower precidence.
This means each eval loops is 1 item smaller than the previous.
==#


function Base.:*(a::Wirtinger, b::Wirtinger)
function Base.:*(a::Union{Complex,AbstractWirtinger},
b::Union{Complex,AbstractWirtinger})
error("""
Cannot multiply two Wirtinger objects; this error likely means a
`WirtingerRule` was inappropriately defined somewhere. Multiplication
Expand All @@ -32,18 +33,33 @@ function Base.:*(a::Wirtinger, b::Wirtinger)
""")
end

function Base.:+(a::Wirtinger, b::Wirtinger)
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
function Base.:+(a::AbstractWirtinger, b::AbstractWirtinger)
return Wirtinger(wirtinger_primal(a) + wirtinger_primal(b),
wirtinger_conjugate(a) + wirtinger_conjugate(b))
end

for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b
Base.:+(a::ComplexGradient, b::ComplexGradient) = ComplexGradient(a.val + b.val)

for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk)
@eval Base.:+(a::AbstractWirtinger, b::$T) = a + Wirtinger(b, Zero())
@eval Base.:+(a::$T, b::AbstractWirtinger) = Wirtinger(a, Zero()) + b

@eval Base.:*(a::Wirtinger, b::$T) = Wirtinger(a.primal * b, a.conjugate * b)
@eval Base.:*(a::$T, b::Wirtinger) = Wirtinger(a * b.primal, a * b.conjugate)

@eval Base.:*(a::ComplexGradient, b::$T) = ComplexGradient(a.val * b)
@eval Base.:*(a::$T, b::ComplexGradient) = ComplexGradient(a * b.val)
end

Base.:+(a::AbstractWirtinger, b) = a + Wirtinger(b, Zero())
Base.:+(a, b::AbstractWirtinger) = Wirtinger(a, Zero()) + b

Base.:*(a::Wirtinger, b::Real) = Wirtinger(a.primal * b, a.conjugate * b)
Base.:*(a::Real, b::Wirtinger) = Wirtinger(a * b.primal, a * b.conjugate)

Base.:*(a::ComplexGradient, b::Real) = ComplexGradient(a.val * b)
Base.:*(a::Real, b::ComplexGradient) = ComplexGradient(a * b.val)


Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value))
Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value))
Expand Down Expand Up @@ -98,3 +114,43 @@ for T in (:Any,)
@eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b
@eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b)
end

@inline chain(outer, inner, swap_order=false) =
_chain(unthunk(outer), unthunk(inner), swap_order)

@inline function _chain(outer, inner, swap_order)
if swap_order
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to write things this way, rather than as e.g. if swap_order; a, b = b, a; end or swap_order && _chain(inner, outer, false)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason is that there's two orders here to consider. The first one, why we need the chain function at all, is concerning, which of the differentials is the partial of the outer and which one is the partial of the inner function. For AbstractWirtinger, this difference does matter, which is the purpose of chain. The second one is the order of multiplication, which matters if inner and outer are non-commutative objects like matrices. They might still be of type Wirtinger, only wirtinger_primal and wirtinger_conjugate are matrices. In general, both orderings are relevant, in Base.:* for example, we want to multiply the outer differential from the right, but this is not equivalent to chain(inner, outer). I will definitely explain this in a docstring.

return Wirtinger(
wirtinger_primal(inner) * wirtinger_primal(outer) +
conj(wirtinger_conjugate(inner)) * wirtinger_conjugate(outer),
wirtinger_conjugate(inner) * wirtinger_primal(outer) +
conj(wirtinger_primal(inner) * wirtinger_conjugate(outer))
) |> refine_differential
end
return Wirtinger(
wirtinger_primal(outer) * wirtinger_primal(inner) +
wirtinger_conjugate(outer) * conj(wirtinger_conjugate(inner)),
wirtinger_primal(outer) * wirtinger_conjugate(inner) +
wirtinger_conjugate(outer) * conj(wirtinger_primal(inner))
) |> refine_differential
end

@inline function _chain(outer::ComplexGradient, inner, swap_order)
if swap_order
return ComplexGradient(
(wirtinger_conjugate(inner) + conj(wirtinger_primal(inner))) *
outer.val
)
end
return ComplexGradient(
outer.val *
(wirtinger_conjugate(inner) + conj(wirtinger_primal(inner)))
)
end

@inline function _chain(outer::ComplexGradient, inner::ComplexGradient, swap_order)
if swap_order
return ComplexGradient(conj(inner.val) * outer.val)
end
return ComplexGradient(outer.val * conj(inner.val))
end
60 changes: 44 additions & 16 deletions src/differentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,29 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself.

@inline Base.conj(x::AbstractDifferential) = x

#####
##### `AbstractWirtinger`
#####

abstract type AbstractWirtinger <: AbstractDifferential end

wirtinger_primal(x) = x
wirtinger_conjugate(::Any) = Zero()

extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type."))

Base.iterate(x::AbstractWirtinger) = (x, nothing)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, I am going to get rid of this, since e.g. Wirtinger can now wrap arrays as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about making this behave like Wirtinger(i, j) for (i,j) in zip(primal, conjugate), but I think this would encourage people to collect this into an array, and we don't ever want to have AbstractArray{<:AbstractWirtinger}.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think the iterate Overloads can probably be removed from everything

Base.iterate(::AbstractWirtinger, ::Any) = nothing

# `conj` is not defined for `AbstractWirtinger`
Base.conj(x::AbstractWirtinger) = throw(MethodError(conj, x))

#####
##### `Wirtinger`
#####

"""
Wirtinger(primal::Union{Number,AbstractDifferential},
conjugate::Union{Number,AbstractDifferential})
Wirtinger(primal, conjugate)

Returns a `Wirtinger` instance representing the complex differential:

Expand All @@ -60,32 +76,32 @@ where `primal` corresponds to `∂f/∂z * dz` and `conjugate` corresponds to `
The two fields of the returned instance can be accessed generically via the
[`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref) methods.
"""
struct Wirtinger{P,C} <: AbstractDifferential
struct Wirtinger{P,C} <: AbstractWirtinger
primal::P
conjugate::C
function Wirtinger(primal::Union{Number,AbstractDifferential},
conjugate::Union{Number,AbstractDifferential})
return new{typeof(primal),typeof(conjugate)}(primal, conjugate)
end
end

wirtinger_primal(x::Wirtinger) = x.primal
wirtinger_primal(x) = x

wirtinger_conjugate(x::Wirtinger) = x.conjugate
wirtinger_conjugate(::Any) = Zero()

extern(x::Wirtinger) = throw(ArgumentError("`Wirtinger` cannot be converted to an external type."))

Base.Broadcast.broadcastable(w::Wirtinger) = Wirtinger(broadcastable(w.primal),
broadcastable(w.conjugate))

Base.iterate(x::Wirtinger) = (x, nothing)
Base.iterate(::Wirtinger, ::Any) = nothing

# TODO: define `conj` for` `Wirtinger`
Base.conj(x::Wirtinger) = throw(MethodError(conj, x))
#####
##### `ComplexGradient`
#####

struct ComplexGradient{T} <: AbstractWirtinger
val::T
end

wirtinger_primal(x::ComplexGradient) = conj(wirtinger_conjugate(x))
wirtinger_conjugate(x::ComplexGradient) = x.val / 2

Base.Broadcast.broadcastable(x::ComplexGradient) = ComplexGradient(broadcastable(x.val))

#####
##### `Casted`
Expand Down Expand Up @@ -191,6 +207,14 @@ end
return element, (externed, new_state)
end

unthunk(x) = x
unthunk(x::AbstractThunk) = unthunk(x())

wirtinger_primal(::Union{AbstractThunk}) =
throw(ArgumentError("`wirtinger_primal` is not defined for `AbstractThunk`. Call `unthunk` first."))
wirtinger_conjugate(::Union{AbstractThunk}) =
throw(ArgumentError("`wirtinger_primal` is not defined for `AbstractThunk`. Call `unthunk` first."))

#####
##### `Thunk`
#####
Expand Down Expand Up @@ -291,14 +315,18 @@ function itself, when that function is not a closure.
const NO_FIELDS = DNE()

"""
refine_differential(𝒟::Type, der)
refine_differential([𝒟::Type, ]der)

Converts, if required, a differential object `der`
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.),
to another differential that is more suited for the domain given by the type 𝒟.
Often this will behave as the identity function on `der`.
"""
function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger)
w = refine_differential(w)
return wirtinger_primal(w) + wirtinger_conjugate(w)
end
refine_differential(::Any, der) = der # most of the time leave it alone.
refine_differential(::Any, der) = refine_differential(der) # most of the time leave it alone.

refine_differential(w::Wirtinger{<:Any,Zero}) = w.primal
refine_differential(der::Any) = der
58 changes: 9 additions & 49 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs]
pushforward_returns = map(1:n_outputs) do output_i
∂s = partials[output_i].args
propagation_expr(𝒟, Δs, ∂s)
frule_propagation_expr(𝒟, Δs, ∂s)
end
if n_outputs > 1
# For forward-mode we only return a tuple if output actually a tuple.
Expand Down Expand Up @@ -193,7 +193,7 @@ function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
# 1 partial derivative per input
pullback_returns = map(1:n_inputs) do input_i
∂s = [partial.args[input_i] for partial in partials]
propagation_expr(𝒟, Δs, ∂s)
rrule_propagation_expr(𝒟, Δs, ∂s)
end

pullback = quote
Expand Down Expand Up @@ -222,56 +222,16 @@ end
if it is taken at `1+1im` it returns `Complex{Int}`.
At present it is ignored for non-Wirtinger derivatives.
"""
function propagation_expr(𝒟, Δs, ∂s)
wirtinger_indices = findall(∂s) do ex
Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger
end
function frule_propagation_expr(𝒟, Δs, ∂s)
∂s = map(esc, ∂s)
if isempty(wirtinger_indices)
return standard_propagation_expr(Δs, ∂s)
else
return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
end
end

function standard_propagation_expr(Δs, ∂s)
# This is basically Δs ⋅ ∂s

# Notice: the thunking of `∂s[i] (potentially) saves us some computation
# if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon
# as the pullback is evaluated
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)]
return :(+($(∂_mul_Δs...)))
∂_mul_Δs = [:(chain(@thunk($(∂s[i])), $(Δs[i]))) for i in 1:length(∂s)]
return :(refine_differential($𝒟, +($(∂_mul_Δs...))))
end

function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
∂_mul_Δs_primal = Any[]
∂_mul_Δs_conjugate = Any[]
∂_wirtinger_defs = Any[]
for i in 1:length(∂s)
if i in wirtinger_indices
Δi = Δs[i]
∂i = Symbol(string(:∂, i))
push!(∂_wirtinger_defs, :($∂i = $(∂s[i])))
∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi))
∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi))
∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi))
∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi))
push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄))
push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄))
else
∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i]))
push!(∂_mul_Δs_primal, ∂_mul_Δ)
push!(∂_mul_Δs_conjugate, ∂_mul_Δ)
end
end
primal_sum = :(+($(∂_mul_Δs_primal...)))
conjugate_sum = :(+($(∂_mul_Δs_conjugate...)))
return quote # This will be a block, so will have value equal to last statement
$(∂_wirtinger_defs...)
w = Wirtinger($primal_sum, $conjugate_sum)
refine_differential($𝒟, w)
end
function rrule_propagation_expr(𝒟, Δs, ∂s)
∂s = map(esc, ∂s)
∂_mul_Δs = [:(chain($(Δs[i]), @thunk($(∂s[i])))) for i in 1:length(∂s)]
return :(refine_differential($𝒟, +($(∂_mul_Δs...))))
end

"""
Expand Down
30 changes: 16 additions & 14 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,32 +46,32 @@ end

@testset "real input" begin
# even though our rule was define in terms of Wirtinger,
# pushforward result will be real as real (even if seed is Compex)
# pushforward result will be real as real (even if seed is Complex)

x = rand(Float64)
x = 5.0
f, myabs2_pushforward = frule(myabs2, x)
@test f === x^2

Δ = One()
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
@test df === x + x

Δ = rand(Complex{Int64})
Δ = 2.0 + 3.0im
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
@test df === Δ * (x + x)
@test df === (Δ + conj(Δ)) * x
end

@testset "complex input" begin
z = rand(Complex{Float64})
z = 5.0 + 7.0im
f, myabs2_pushforward = frule(myabs2, z)
@test f === abs2(z)

df = @inferred myabs2_pushforward(NamedTuple(), One())
@test df === Wirtinger(z', z)

Δ = rand(Complex{Int64})
Δ = 2.0 + 3.0im
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
@test df === Wirtinger(Δ * z', Δ * z)
@test df === Wirtinger(Δ * conj(z), conj(Δ) * z)
end
end

Expand All @@ -97,7 +97,9 @@ end
abs_to_pow(x::Complex, p),
@setup(u = abs(x)),
(
p == 0 ? Zero() : p * u^(p-1) * Wirtinger(x' / 2u, x / 2u),
p == 0 ? Zero() : let v = p * u^(p-1) / 2u
Wirtinger(x' * v, x * v)
end,
Ω * log(abs(x))
)
)
Expand Down Expand Up @@ -132,11 +134,11 @@ end
fx, f_pushforward = res
df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp)

df_dx::Thunk = df(One(), Zero())
df_dp::Thunk = df(Zero(), One())
df_dx = df(One(), Zero())
df_dp = df(Zero(), One())
@test fx == f(x, p) # Check we still get the normal value, right
@test df_dx() isa expected_type_df_dx
@test df_dp() isa expected_type_df_dp
@test df_dx isa expected_type_df_dx
@test df_dp isa expected_type_df_dp


res = rrule(f, x, p)
Expand All @@ -145,7 +147,7 @@ end
dself, df_dx, df_dp = f_pullback(One())
@test fx == f(x, p) # Check we still get the normal value, right
@test dself == NO_FIELDS
@test df_dx() isa expected_type_df_dx
@test df_dp() isa expected_type_df_dp
@test df_dx isa expected_type_df_dx
@test df_dp isa expected_type_df_dp
end
end