WIP: Wirtinger support #54
Changes from 19 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,13 +41,44 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself. | |
|
||
@inline Base.conj(x::AbstractDifferential) = x | ||
|
||
##### | ||
##### `AbstractWirtinger` | ||
##### | ||
|
||
""" | ||
AbstractWirtinger <: AbstractDifferential | ||
|
||
Represents the differential of a non-holomorphic function taking complex input. | ||
|
||
All subtypes implement [`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref). | ||
|
||
All subtypes wrap real/holomorphic differentials, and should always be the outermost wrapper. | ||
E.g., a typical differential would look like this: | ||
``` | ||
Wirtinger(@thunk(::AbstractArray{Number}), @thunk(::AbstractArray{<:Number})) | ||
Review comment: After thinking about this quite a bit, I think this order makes the most sense, since this way we can avoid allocations completely for derivatives like |
||
``` | ||
`@thunk` and `AbstractArray` are, of course, optional. | ||
""" | ||
abstract type AbstractWirtinger <: AbstractDifferential end | ||
|
||
wirtinger_primal(x) = x | ||
wirtinger_conjugate(::Any) = Zero() | ||
|
||
extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type.")) | ||
|
||
Base.iterate(x::AbstractWirtinger) = (x, nothing) | ||
Review comment: I think I am going to get rid of this, since e.g.
Review comment: I thought about making this behave like
Review comment: Yeah, I think the `iterate` overloads can probably be removed from everything |
||
Base.iterate(::AbstractWirtinger, ::Any) = nothing | ||
|
||
# `conj` is not defined for `AbstractWirtinger`. | ||
# Need this method to override the definition of `conj` for `AbstractDifferential`. | ||
Base.conj(x::AbstractWirtinger) = throw(MethodError(conj, x)) | ||
|
||
##### | ||
##### `Wirtinger` | ||
##### | ||
|
||
""" | ||
Wirtinger(primal::Union{Number,AbstractDifferential}, | ||
conjugate::Union{Number,AbstractDifferential}) | ||
Wirtinger(primal, conjugate) <: [`AbstractWirtinger`](@ref) | ||
|
||
Returns a `Wirtinger` instance representing the complex differential: | ||
|
||
|
@@ -60,32 +91,43 @@ where `primal` corresponds to `∂f/∂z * dz` and `conjugate` corresponds to ` | |
The two fields of the returned instance can be accessed generically via the | ||
[`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref) methods. | ||
""" | ||
struct Wirtinger{P,C} <: AbstractDifferential | ||
struct Wirtinger{P,C} <: AbstractWirtinger | ||
primal::P | ||
conjugate::C | ||
function Wirtinger(primal::Union{Number,AbstractDifferential}, | ||
conjugate::Union{Number,AbstractDifferential}) | ||
return new{typeof(primal),typeof(conjugate)}(primal, conjugate) | ||
end | ||
end | ||
|
||
wirtinger_primal(x::Wirtinger) = x.primal | ||
wirtinger_primal(x) = x | ||
|
||
wirtinger_conjugate(x::Wirtinger) = x.conjugate | ||
wirtinger_conjugate(::Any) = Zero() | ||
|
||
extern(x::Wirtinger) = throw(ArgumentError("`Wirtinger` cannot be converted to an external type.")) | ||
|
||
Base.Broadcast.broadcastable(w::Wirtinger) = Wirtinger(broadcastable(w.primal), | ||
broadcastable(w.conjugate)) | ||
|
||
Base.iterate(x::Wirtinger) = (x, nothing) | ||
Base.iterate(::Wirtinger, ::Any) = nothing | ||
|
||
# TODO: define `conj` for` `Wirtinger` | ||
Base.conj(x::Wirtinger) = throw(MethodError(conj, x)) | ||
##### | ||
##### `ComplexGradient` | ||
##### | ||
|
||
""" | ||
ComplexGradient(val) <: [`AbstractWirtinger`](@ref) | ||
|
||
Returns a `ComplexGradient` instance representing the complex differential: | ||
|
||
``` | ||
df = ∂f/∂Re(z) * dRe(z) + im * ∂f/∂Im(z) * dIm(z) | ||
``` | ||
|
||
where `f` is a `ℂ(^n) -> ℝ(^m)` function and `val` corresponds to `df`. | ||
""" | ||
struct ComplexGradient{T} <: AbstractWirtinger | ||
|
||
val::T | ||
end | ||
|
||
wirtinger_primal(x::ComplexGradient) = conj(wirtinger_conjugate(x)) | ||
wirtinger_conjugate(x::ComplexGradient) = (1//2) * x.val | ||
|
||
Base.Broadcast.broadcastable(x::ComplexGradient) = ComplexGradient(broadcastable(x.val)) | ||
|
||
##### | ||
##### `Casted` | ||
|
@@ -131,6 +173,7 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero()) | |
Base.iterate(x::Zero) = (x, nothing) | ||
Base.iterate(::Zero, ::Any) = nothing | ||
|
||
Base.real(::Zero) = Zero() | ||
|
||
##### | ||
##### `DNE` | ||
|
@@ -172,6 +215,7 @@ Base.Broadcast.broadcastable(::One) = Ref(One()) | |
Base.iterate(x::One) = (x, nothing) | ||
Base.iterate(::One, ::Any) = nothing | ||
|
||
Base.real(::One) = One() | ||
|
||
##### | ||
##### `AbstractThunk | ||
|
@@ -191,6 +235,16 @@ end | |
return element, (externed, new_state) | ||
end | ||
|
||
unthunk(x) = x | ||
unthunk(x::AbstractThunk) = unthunk(x()) | ||
|
||
|
||
wirtinger_primal(::Union{AbstractThunk}) = | ||
throw(ArgumentError("`wirtinger_primal` is not defined for `AbstractThunk`. Call `unthunk` first.")) | ||
wirtinger_conjugate(::Union{AbstractThunk}) = | ||
throw(ArgumentError("`wirtinger_conjugate` is not defined for `AbstractThunk`. Call `unthunk` first.")) | ||
Review comment: Weird Also not sure about this function but will wait and see
Review comment: Yes, the |
||
|
||
Base.real(x::AbstractThunk) = real(x()) | ||
|
||
##### | ||
##### `Thunk` | ||
##### | ||
|
@@ -239,6 +293,11 @@ struct Thunk{F} <: AbstractThunk | |
f::F | ||
end | ||
|
||
""" | ||
@thunk body | ||
|
||
Returns `Thunk(() -> body)` | ||
""" | ||
macro thunk(body) | ||
return :(Thunk(() -> $(esc(body)))) | ||
end | ||
|
@@ -291,14 +350,24 @@ function itself, when that function is not a closure. | |
const NO_FIELDS = DNE() | ||
|
||
""" | ||
refine_differential(𝒟::Type, der) | ||
refine_differential([𝒟::Type, ]der) | ||
|
||
Converts, if required, a differential object `der` | ||
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.), | ||
to another differential that is more suited for the domain given by the type 𝒟. | ||
Often this will behave as the identity function on `der`. | ||
""" | ||
function refine_differential end | ||
|
||
function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger) | ||
w = refine_differential(w) | ||
return wirtinger_primal(w) + wirtinger_conjugate(w) | ||
end | ||
refine_differential(::Any, der) = der # most of the time leave it alone. | ||
function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, g::ComplexGradient) | ||
g = refine_differential(g.val) | ||
return real(g) | ||
end | ||
refine_differential(::Any, der) = refine_differential(der) # most of the time leave it alone. | ||
|
||
refine_differential(w::Wirtinger{<:Any,Zero}) = w.primal | ||
|
||
refine_differential(der::Any) = der | ||
|
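The definitions of `wirtinger_primal` and `wirtinger_conjugate` for `ComplexGradient`, and the two real-domain `refine_differential` methods above, can be checked by hand. The derivation below is a sketch under two assumptions that are not stated in the diff: the standard definition of the Wirtinger derivatives with $z = x + iy$, and the reading that `val` stores $\partial f/\partial x + i\,\partial f/\partial y$ for a real-valued $f$.

```latex
\frac{\partial f}{\partial z} = \frac{1}{2}\left(\frac{\partial f}{\partial x} - i\,\frac{\partial f}{\partial y}\right),
\qquad
\frac{\partial f}{\partial \bar z} = \frac{1}{2}\left(\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y}\right)

% With val = \partial f/\partial x + i\,\partial f/\partial y and f real-valued:
\frac{\partial f}{\partial \bar z} = \tfrac{1}{2}\,\mathrm{val},
\qquad
\frac{\partial f}{\partial z} = \tfrac{1}{2}\,\overline{\mathrm{val}} = \overline{\left(\frac{\partial f}{\partial \bar z}\right)}

% Restricting to a real input (dz = d\bar z = dx):
df = \left(\frac{\partial f}{\partial z} + \frac{\partial f}{\partial \bar z}\right) dx
   = \frac{\partial f}{\partial x}\,dx
   = \operatorname{Re}(\mathrm{val})\,dx
```

Under these assumptions, `(1//2) * x.val` is the conjugate component, its complex conjugate is the primal component, and both `wirtinger_primal(w) + wirtinger_conjugate(w)` and `real(g)` reduce to $\partial f/\partial x$ on a real domain.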
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -156,7 +156,7 @@ function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials) | |
Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs] | ||
pushforward_returns = map(1:n_outputs) do output_i | ||
∂s = partials[output_i].args | ||
propagation_expr(𝒟, Δs, ∂s) | ||
frule_propagation_expr(𝒟, Δs, ∂s) | ||
end | ||
if n_outputs > 1 | ||
# For forward-mode we only return a tuple if output actually a tuple. | ||
|
@@ -193,7 +193,7 @@ function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials) | |
# 1 partial derivative per input | ||
pullback_returns = map(1:n_inputs) do input_i | ||
∂s = [partial.args[input_i] for partial in partials] | ||
propagation_expr(𝒟, Δs, ∂s) | ||
rrule_propagation_expr(𝒟, Δs, ∂s) | ||
end | ||
|
||
pullback = quote | ||
|
@@ -222,56 +222,32 @@ end | |
if it is taken at `1+1im` it returns `Complex{Int}`. | ||
At present it is ignored for non-Wirtinger derivatives. | ||
""" | ||
function propagation_expr(𝒟, Δs, ∂s) | ||
wirtinger_indices = findall(∂s) do ex | ||
Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger | ||
end | ||
function frule_propagation_expr(𝒟, Δs, ∂s) | ||
∂s = map(esc, ∂s) | ||
if isempty(wirtinger_indices) | ||
return standard_propagation_expr(Δs, ∂s) | ||
else | ||
return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s) | ||
end | ||
∂_mul_Δs = [:(chain(@_thunk($(∂s[i])), $(Δs[i]))) for i in 1:length(∂s)] | ||
return :(refine_differential($𝒟, +($(∂_mul_Δs...)))) | ||
end | ||
|
||
function standard_propagation_expr(Δs, ∂s) | ||
# This is basically Δs ⋅ ∂s | ||
|
||
# Notice: the thunking of `∂s[i] (potentially) saves us some computation | ||
# if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon | ||
# as the pullback is evaluated | ||
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)] | ||
return :(+($(∂_mul_Δs...))) | ||
function rrule_propagation_expr(𝒟, Δs, ∂s) | ||
∂s = map(esc, ∂s) | ||
∂_mul_Δs = [:(chain($(Δs[i]), @_thunk($(∂s[i])))) for i in 1:length(∂s)] | ||
return :(refine_differential($𝒟, +($(∂_mul_Δs...)))) | ||
end | ||
|
||
function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s) | ||
∂_mul_Δs_primal = Any[] | ||
∂_mul_Δs_conjugate = Any[] | ||
∂_wirtinger_defs = Any[] | ||
for i in 1:length(∂s) | ||
if i in wirtinger_indices | ||
Δi = Δs[i] | ||
∂i = Symbol(string(:∂, i)) | ||
push!(∂_wirtinger_defs, :($∂i = $(∂s[i]))) | ||
∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi)) | ||
∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi)) | ||
∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi)) | ||
∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi)) | ||
push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄)) | ||
push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄)) | ||
else | ||
∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i])) | ||
push!(∂_mul_Δs_primal, ∂_mul_Δ) | ||
push!(∂_mul_Δs_conjugate, ∂_mul_Δ) | ||
""" | ||
@_thunk body | ||
|
||
Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref). | ||
In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`. | ||
""" | ||
macro _thunk(body) | ||
Review comment: This doesn't need to be a macro anymore, as it is only called from within a function of ASTs.
Review comment: Could you take a look at whether everything is escaped correctly now? I'm not 100% sure I know how escaping works. |
||
if body isa Expr && body.head == :call | ||
|
||
fname = body.args[1] | ||
if fname in (:Wirtinger, :ComplexGradient) | ||
return :($fname($((:(@thunk $(esc(i))) for i in body.args[2:end])...))) | ||
end | ||
end | ||
primal_sum = :(+($(∂_mul_Δs_primal...))) | ||
conjugate_sum = :(+($(∂_mul_Δs_conjugate...))) | ||
return quote # This will be a block, so will have value equal to last statement | ||
$(∂_wirtinger_defs...) | ||
w = Wirtinger($primal_sum, $conjugate_sum) | ||
refine_differential($𝒟, w) | ||
end | ||
return :(@thunk $(esc(body))) | ||
end | ||
|
||
""" | ||
|
Review comment: Why do we need to write things this way, rather than as e.g. `if swap_order; a, b = b, a; end` or `swap_order && _chain(inner, outer, false)`?

Review comment: The reason is that there are two orders to consider here. The first, which is why we need the `chain` function at all, concerns which of the differentials is the partial of the outer function and which is the partial of the inner function. For `AbstractWirtinger`, this difference does matter, which is the purpose of `chain`. The second is the order of multiplication, which matters if `inner` and `outer` are non-commutative objects like matrices. They might still be of type `Wirtinger`, only `wirtinger_primal` and `wirtinger_conjugate` are matrices. In general, both orderings are relevant: in `Base.:*`, for example, we want to multiply the outer differential from the right, but this is not equivalent to `chain(inner, outer)`. I will definitely explain this in a docstring.
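The first of the two orderings described above (which differential belongs to the outer function and which to the inner one) can be illustrated with a small, self-contained sketch. It does not use the PR's `chain` or `Wirtinger`: `W` and `chain_sketch` are hypothetical stand-ins, and the components are assumed to be scalars, so the order of the factors inside each product plays no role here. The formulas are the usual Wirtinger chain rule for `h = f ∘ g`.

```julia
# Hypothetical stand-in for a Wirtinger-style differential (not the PR's type).
struct W{P,C}
    primal::P      # ∂/∂z component
    conjugate::C   # ∂/∂z̄ component
end

# Chain rule for h = f ∘ g with scalar components.
# `outer` holds the partials of f at g(z); `inner` holds the partials of g at z.
# Only the inner components are conjugated, so the arguments are not interchangeable.
function chain_sketch(outer::W, inner::W)
    p = outer.primal * inner.primal + outer.conjugate * conj(inner.conjugate)
    c = outer.primal * inner.conjugate + outer.conjugate * conj(inner.primal)
    return W(p, c)
end

# f(w) = abs2(w) = w * conj(w):  ∂f/∂w = conj(w), ∂f/∂w̄ = w
# g(z) = im * z (holomorphic):   ∂g/∂z = im,      ∂g/∂z̄ = 0
z = 1.0 + 2.0im
w = im * z
df = W(conj(w), w)
dg = W(1.0im, 0.0im)

# h(z) = abs2(im * z) = abs2(z), so ∂h/∂z = conj(z) and ∂h/∂z̄ = z.
correct = chain_sketch(df, dg)
# Swapping the roles of outer and inner gives a different conjugate component.
swapped = chain_sketch(dg, df)
```

With matrix-valued components, the order of the factors inside each product would matter as well, which is the second ordering mentioned in the reply; this sketch deliberately leaves that out.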