From 1548cbcbb5d91a8c6811f2bb30a0e5808f5df8ac Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Sun, 22 Sep 2019 22:39:30 +0200 Subject: [PATCH 01/22] remove type constraints for Wirtinger --- src/differentials.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/differentials.jl b/src/differentials.jl index 5ad2f8818..807427de1 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -63,10 +63,6 @@ The two fields of the returned instance can be accessed generically via the struct Wirtinger{P,C} <: AbstractDifferential primal::P conjugate::C - function Wirtinger(primal::Union{Number,AbstractDifferential}, - conjugate::Union{Number,AbstractDifferential}) - return new{typeof(primal),typeof(conjugate)}(primal, conjugate) - end end wirtinger_primal(x::Wirtinger) = x.primal From 88bb75682ceece20e5d36fa1a182121cd91a5693 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Mon, 23 Sep 2019 16:15:11 +0200 Subject: [PATCH 02/22] introduce `AbstractWirtinger` and `ComplexGradient` --- src/differential_arithmetic.jl | 30 +++++++++++++++++++------ src/differentials.jl | 40 +++++++++++++++++++++++++--------- 2 files changed, 53 insertions(+), 17 deletions(-) diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index e65748d34..6d8dfb60e 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -7,14 +7,15 @@ subtypes, as we know the full set that might be encountered. Thus we can avoid any ambiguities. Notice: - The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any) + The precidence goes: (:AbstractWirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any) Thus each of the @eval loops creating definitions of + and * defines the combination this type with all types of lower precidence. This means each eval loops is 1 item smaller than the previous. ==# -function Base.:*(a::Wirtinger, b::Wirtinger) +function Base.:*(a::Union{Complex,AbstractWirtinger}, + b::Union{Complex,AbstractWirtinger}) error(""" Cannot multiply two Wirtinger objects; this error likely means a `WirtingerRule` was inappropriately defined somewhere. Multiplication @@ -32,18 +33,33 @@ function Base.:*(a::Wirtinger, b::Wirtinger) """) end -function Base.:+(a::Wirtinger, b::Wirtinger) - return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate) +function Base.:+(a::AbstractWirtinger, b::AbstractWirtinger) + return Wirtinger(wirtinger_primal(a) + wirtinger_primal(b), + wirtinger_conjugate(a) + wirtinger_conjugate(b)) end -for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any) - @eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero()) - @eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b +Base.:+(a::ComplexGradient, b::ComplexGradient) = ComplexGradient(a.val + b.val) + +for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk) + @eval Base.:+(a::AbstractWirtinger, b::$T) = a + Wirtinger(b, Zero()) + @eval Base.:+(a::$T, b::AbstractWirtinger) = Wirtinger(a, Zero()) + b @eval Base.:*(a::Wirtinger, b::$T) = Wirtinger(a.primal * b, a.conjugate * b) @eval Base.:*(a::$T, b::Wirtinger) = Wirtinger(a * b.primal, a * b.conjugate) + + @eval Base.:*(a::ComplexGradient, b::$T) = ComplexGradient(a.val * b) + @eval Base.:*(a::$T, b::ComplexGradient) = ComplexGradient(a * b.val) end +Base.:+(a::AbstractWirtinger, b) = a + Wirtinger(b, Zero()) +Base.:+(a, b::AbstractWirtinger) = Wirtinger(a, Zero()) + b + +Base.:*(a::Wirtinger, b::Real) = Wirtinger(a.primal * b, a.conjugate * b) +Base.:*(a::Real, b::Wirtinger) = Wirtinger(a * b.primal, a * b.conjugate) + +Base.:*(a::ComplexGradient, b::Real) = ComplexGradient(a.val * b) +Base.:*(a::Real, b::ComplexGradient) = ComplexGradient(a * b.val) + Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value)) Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value)) diff --git a/src/differentials.jl b/src/differentials.jl index 807427de1..a45b0c938 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -41,13 +41,29 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself. @inline Base.conj(x::AbstractDifferential) = x +##### +##### `AbstractWirtinger` +##### + +abstract type AbstractWirtinger <: AbstractDifferential end + +wirtinger_primal(x) = x +wirtinger_conjugate(::Any) = Zero() + +extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type.")) + +Base.iterate(x::AbstractWirtinger) = (x, nothing) +Base.iterate(::AbstractWirtinger, ::Any) = nothing + +# `conj` is not defined for `AbstractWirtinger` +Base.conj(x::AbstractWirtinger) = throw(MethodError(conj, x)) + ##### ##### `Wirtinger` ##### """ - Wirtinger(primal::Union{Number,AbstractDifferential}, - conjugate::Union{Number,AbstractDifferential}) + Wirtinger(primal, conjugate) Returns a `Wirtinger` instance representing the complex differential: @@ -60,18 +76,13 @@ where `primal` corresponds to `βˆ‚f/βˆ‚z * dz` and `conjugate` corresponds to ` The two fields of the returned instance can be accessed generically via the [`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref) methods. """ -struct Wirtinger{P,C} <: AbstractDifferential +struct Wirtinger{P,C} <: AbstractWirtinger primal::P conjugate::C end wirtinger_primal(x::Wirtinger) = x.primal -wirtinger_primal(x) = x - wirtinger_conjugate(x::Wirtinger) = x.conjugate -wirtinger_conjugate(::Any) = Zero() - -extern(x::Wirtinger) = throw(ArgumentError("`Wirtinger` cannot be converted to an external type.")) Base.Broadcast.broadcastable(w::Wirtinger) = Wirtinger(broadcastable(w.primal), broadcastable(w.conjugate)) @@ -79,9 +90,18 @@ Base.Broadcast.broadcastable(w::Wirtinger) = Wirtinger(broadcastable(w.primal), Base.iterate(x::Wirtinger) = (x, nothing) Base.iterate(::Wirtinger, ::Any) = nothing -# TODO: define `conj` for` `Wirtinger` -Base.conj(x::Wirtinger) = throw(MethodError(conj, x)) +##### +##### `ComplexGradient` +##### + +struct ComplexGradient{T} <: AbstractWirtinger + val::T +end + +wirtinger_primal(x::ComplexGradient) = conj(wirtinger_conjugate(x)) +wirtinger_conjugate(x::ComplexGradient) = x.val / 2 +Base.Broadcast.broadcastable(x::ComplexGradient) = ComplexGradient(broadcastable(x.val)) ##### ##### `Casted` From e3ce5389435a5c022cc8ec5da6c73cf3a8f777f4 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Mon, 23 Sep 2019 20:39:46 +0200 Subject: [PATCH 03/22] add `chain` function --- src/ChainRulesCore.jl | 4 ++-- src/differential_arithmetic.jl | 37 ++++++++++++++++++++++++++++++++++ src/differentials.jl | 8 ++++++-- 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 118e7f841..e0a134385 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -4,8 +4,8 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad export frule, rrule export wirtinger_conjugate, wirtinger_primal, refine_differential export @scalar_rule, @thunk -export extern, cast, store! -export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk +export extern, chain, cast, store! +export Wirtinger, ComplexGradient, Zero, One, Casted, DNE, Thunk, InplaceableThunk export NO_FIELDS include("differentials.jl") diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index 6d8dfb60e..25d60919d 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -114,3 +114,40 @@ for T in (:Any,) @eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b @eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b) end + +function chain(outer, inner; swap_order=false) + if swap_order + return Wirtinger( + wirtinger_primal(inner) * wirtinger_primal(outer) + + conj(wirtinger_conjugate(inner)) * wirtinger_conjugate(outer), + wirtinger_conjugate(inner) * wirtinger_primal(outer) + + conj(wirtinger_primal(inner) * wirtinger_conjugate(outer)) + ) |> refine_differential + end + return Wirtinger( + wirtinger_primal(outer) * wirtinger_primal(inner) + + wirtinger_conjugate(outer) * conj(wirtinger_conjugate(inner)), + wirtinger_primal(outer) * wirtinger_conjugate(inner) + + wirtinger_conjugate(outer) * conj(wirtinger_primal(inner)) + ) |> refine_differential +end + +function chain(outer::ComplexGradient, inner; swap_order=false) + if swap_order + return ComplexGradient( + wirtinger_conjugate(inner) + conj(wirtinger_primal(inner)) * + outer.val + ) + end + return ComplexGradient( + outer.val * + (wirtinger_conjugate(inner) + conj(wirtinger_primal(inner))) + ) +end + +function chain(outer::ComplexGradient, inner::ComplexGradient; swap_order=false) + if swap_order + return ComplexGradient(conj(inner.val) * outer.val) + end + return ComplexGradient(outer.val * conj(inner.val)) +end diff --git a/src/differentials.jl b/src/differentials.jl index a45b0c938..65750bb77 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -307,7 +307,7 @@ function itself, when that function is not a closure. const NO_FIELDS = DNE() """ - refine_differential(π’Ÿ::Type, der) + refine_differential([π’Ÿ::Type, ]der) Converts, if required, a differential object `der` (e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.), @@ -315,6 +315,10 @@ to another differential that is more suited for the domain given by the type Often this will behave as the identity function on `der`. """ function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger) + w = refine_differential(w) return wirtinger_primal(w) + wirtinger_conjugate(w) end -refine_differential(::Any, der) = der # most of the time leave it alone. +refine_differential(::Any, der) = refine_differential(der) # most of the time leave it alone. + +refine_differential(w::Wirtinger{<:Any,Zero}) = w.primal +refine_differential(der::Any) = der From 186009aeca867735db92121da655d7a899e9ca86 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Tue, 24 Sep 2019 00:19:44 +0200 Subject: [PATCH 04/22] make `swap_order` in `chain` a positional arg This, together with adding `@inline` makes constant-propagation possible. Also fix a bug from before. --- src/differential_arithmetic.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index 25d60919d..cdaeaa59c 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -115,7 +115,7 @@ for T in (:Any,) @eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b) end -function chain(outer, inner; swap_order=false) +@inline function chain(outer, inner, swap_order=false) if swap_order return Wirtinger( wirtinger_primal(inner) * wirtinger_primal(outer) + @@ -132,10 +132,10 @@ function chain(outer, inner; swap_order=false) ) |> refine_differential end -function chain(outer::ComplexGradient, inner; swap_order=false) +@inline function chain(outer::ComplexGradient, inner, swap_order=false) if swap_order return ComplexGradient( - wirtinger_conjugate(inner) + conj(wirtinger_primal(inner)) * + (wirtinger_conjugate(inner) + conj(wirtinger_primal(inner))) * outer.val ) end @@ -145,7 +145,7 @@ function chain(outer::ComplexGradient, inner; swap_order=false) ) end -function chain(outer::ComplexGradient, inner::ComplexGradient; swap_order=false) +@inline function chain(outer::ComplexGradient, inner::ComplexGradient, swap_order=false) if swap_order return ComplexGradient(conj(inner.val) * outer.val) end From 2618146da8bf901494ff75720fa51578c4276b9d Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Tue, 24 Sep 2019 00:58:26 +0200 Subject: [PATCH 05/22] introduce a function `unwrap_wirtinger` --- src/differential_arithmetic.jl | 9 ++++++--- src/differentials.jl | 7 +++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index cdaeaa59c..7acf96046 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -115,7 +115,10 @@ for T in (:Any,) @eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b) end -@inline function chain(outer, inner, swap_order=false) +@inline chain(outer, inner, swap_order=false) = + _chain(unwrap_wirtiner(outer), unwrap_wirtinger(inner), swap_order) + +@inline function _chain(outer, inner, swap_order) if swap_order return Wirtinger( wirtinger_primal(inner) * wirtinger_primal(outer) + @@ -132,7 +135,7 @@ end ) |> refine_differential end -@inline function chain(outer::ComplexGradient, inner, swap_order=false) +@inline function _chain(outer::ComplexGradient, inner, swap_order) if swap_order return ComplexGradient( (wirtinger_conjugate(inner) + conj(wirtinger_primal(inner))) * @@ -145,7 +148,7 @@ end ) end -@inline function chain(outer::ComplexGradient, inner::ComplexGradient, swap_order=false) +@inline function _chain(outer::ComplexGradient, inner::ComplexGradient, swap_order) if swap_order return ComplexGradient(conj(inner.val) * outer.val) end diff --git a/src/differentials.jl b/src/differentials.jl index 65750bb77..fdbdacab4 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -47,8 +47,15 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself. abstract type AbstractWirtinger <: AbstractDifferential end +unwrap_wirtinger(x) = x +unwrap_wirtinger(x::Union{Casted,AbstractThunk}) = unwrap_wirtinger(extern(x)) + wirtinger_primal(x) = x +wirtinger_primal(x::Union{Casted,AbstractThunk}) = + throw(ArgumentError("`wirtinger_primal` is not defined for $(typeof(x)). Call `unwrap_wirtinger` first") wirtinger_conjugate(::Any) = Zero() +wirtinger_primal(x::Union{Casted,AbstractThunk}) = + throw(ArgumentError("`wirtinger_conjugate` is not defined for $(typeof(x)). Call `unwrap_wirtinger` first") extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type.")) From aa7a84ad47ee352df40353a48bd9e3d8bdebbf38 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Tue, 24 Sep 2019 01:12:27 +0200 Subject: [PATCH 06/22] stop using types before they are defined --- src/differentials.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/differentials.jl b/src/differentials.jl index fdbdacab4..4d615537e 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -48,14 +48,9 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself. abstract type AbstractWirtinger <: AbstractDifferential end unwrap_wirtinger(x) = x -unwrap_wirtinger(x::Union{Casted,AbstractThunk}) = unwrap_wirtinger(extern(x)) wirtinger_primal(x) = x -wirtinger_primal(x::Union{Casted,AbstractThunk}) = - throw(ArgumentError("`wirtinger_primal` is not defined for $(typeof(x)). Call `unwrap_wirtinger` first") wirtinger_conjugate(::Any) = Zero() -wirtinger_primal(x::Union{Casted,AbstractThunk}) = - throw(ArgumentError("`wirtinger_conjugate` is not defined for $(typeof(x)). Call `unwrap_wirtinger` first") extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type.")) @@ -214,6 +209,13 @@ end return element, (externed, new_state) end +unwrap_wirtinger(x::Union{Casted,AbstractThunk}) = unwrap_wirtinger(extern(x)) + +wirtinger_primal(x::Union{Casted,AbstractThunk}) = + throw(ArgumentError("`wirtinger_primal` is not defined for $(typeof(x)). Call `unwrap_wirtinger` first")) +wirtinger_primal(x::Union{Casted,AbstractThunk}) = + throw(ArgumentError("`wirtinger_conjugate` is not defined for $(typeof(x)). Call `unwrap_wirtinger` first")) + ##### ##### `Thunk` ##### From 20e7134eb65b97592910042a08dcabb4d2d6bf92 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Wed, 25 Sep 2019 18:42:37 +0200 Subject: [PATCH 07/22] fix tests --- test/rules.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/rules.jl b/test/rules.jl index e23680326..2e0e89dc5 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -97,7 +97,9 @@ end abs_to_pow(x::Complex, p), @setup(u = abs(x)), ( - p == 0 ? Zero() : p * u^(p-1) * Wirtinger(x' / 2u, x / 2u), + p == 0 ? Zero() : let v = p * u^(p-1) / 2u + Wirtinger(x' * v, x * v) + end, Ξ© * log(abs(x)) ) ) From c42a953642e2d2e568ce968a27b0b9d9d5e6ad57 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Fri, 27 Sep 2019 12:00:37 +0200 Subject: [PATCH 08/22] rename `unwrap_wirtinger` -> `unthunk` --- src/differentials.jl | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/differentials.jl b/src/differentials.jl index 4d615537e..832e42845 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -47,8 +47,6 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself. abstract type AbstractWirtinger <: AbstractDifferential end -unwrap_wirtinger(x) = x - wirtinger_primal(x) = x wirtinger_conjugate(::Any) = Zero() @@ -209,12 +207,13 @@ end return element, (externed, new_state) end -unwrap_wirtinger(x::Union{Casted,AbstractThunk}) = unwrap_wirtinger(extern(x)) +unthunk(x) = x +unthunk(x::AbstractThunk) = unthunk(x()) -wirtinger_primal(x::Union{Casted,AbstractThunk}) = - throw(ArgumentError("`wirtinger_primal` is not defined for $(typeof(x)). Call `unwrap_wirtinger` first")) -wirtinger_primal(x::Union{Casted,AbstractThunk}) = - throw(ArgumentError("`wirtinger_conjugate` is not defined for $(typeof(x)). Call `unwrap_wirtinger` first")) +wirtinger_primal(::Union{AbstractThunk}) = + throw(ArgumentError("`wirtinger_primal` is not defined for `AbstractThunk`. Call `unthunk` first.")) +wirtinger_conjugate(::Union{AbstractThunk}) = + throw(ArgumentError("`wirtinger_primal` is not defined for `AbstractThunk`. Call `unthunk` first.")) ##### ##### `Thunk` From 2ac880163fcbf3a134e259695ab231515cdadbc9 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Fri, 27 Sep 2019 12:03:15 +0200 Subject: [PATCH 09/22] fix `chain` function --- src/differential_arithmetic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index 7acf96046..04e09d62f 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -116,7 +116,7 @@ for T in (:Any,) end @inline chain(outer, inner, swap_order=false) = - _chain(unwrap_wirtiner(outer), unwrap_wirtinger(inner), swap_order) + _chain(unthunk(outer), unthunk(inner), swap_order) @inline function _chain(outer, inner, swap_order) if swap_order From 61a60ec548fee833704e1e50633bfbb737c92694 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Fri, 27 Sep 2019 12:11:30 +0200 Subject: [PATCH 10/22] use the new `chain` function in `@scalar_rule` --- src/rule_definition_tools.jl | 58 ++++++------------------------------ 1 file changed, 9 insertions(+), 49 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index a06820e64..50cb7cf75 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -156,7 +156,7 @@ function scalar_frule_expr(π’Ÿ, f, call, setup_stmts, inputs, partials) Ξ”s = [Symbol(string(:Ξ”, i)) for i in 1:n_inputs] pushforward_returns = map(1:n_outputs) do output_i βˆ‚s = partials[output_i].args - propagation_expr(π’Ÿ, Ξ”s, βˆ‚s) + frule_propagation_expr(π’Ÿ, Ξ”s, βˆ‚s) end if n_outputs > 1 # For forward-mode we only return a tuple if output actually a tuple. @@ -193,7 +193,7 @@ function scalar_rrule_expr(π’Ÿ, f, call, setup_stmts, inputs, partials) # 1 partial derivative per input pullback_returns = map(1:n_inputs) do input_i βˆ‚s = [partial.args[input_i] for partial in partials] - propagation_expr(π’Ÿ, Ξ”s, βˆ‚s) + rrule_propagation_expr(π’Ÿ, Ξ”s, βˆ‚s) end pullback = quote @@ -222,56 +222,16 @@ end if it is taken at `1+1im` it returns `Complex{Int}`. At present it is ignored for non-Wirtinger derivatives. """ -function propagation_expr(π’Ÿ, Ξ”s, βˆ‚s) - wirtinger_indices = findall(βˆ‚s) do ex - Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger - end +function frule_propagation_expr(π’Ÿ, Ξ”s, βˆ‚s) βˆ‚s = map(esc, βˆ‚s) - if isempty(wirtinger_indices) - return standard_propagation_expr(Ξ”s, βˆ‚s) - else - return wirtinger_propagation_expr(π’Ÿ, wirtinger_indices, Ξ”s, βˆ‚s) - end -end - -function standard_propagation_expr(Ξ”s, βˆ‚s) - # This is basically Ξ”s β‹… βˆ‚s - - # Notice: the thunking of `βˆ‚s[i] (potentially) saves us some computation - # if `Ξ”s[i]` is a `AbstractDifferential` otherwise it is computed as soon - # as the pullback is evaluated - βˆ‚_mul_Ξ”s = [:(@thunk($(βˆ‚s[i])) * $(Ξ”s[i])) for i in 1:length(βˆ‚s)] - return :(+($(βˆ‚_mul_Ξ”s...))) + βˆ‚_mul_Ξ”s = [:(chain(@thunk($(βˆ‚s[i])), $(Ξ”s[i]))) for i in 1:length(βˆ‚s)] + return :(refine_differential($π’Ÿ, +($(βˆ‚_mul_Ξ”s...)))) end -function wirtinger_propagation_expr(π’Ÿ, wirtinger_indices, Ξ”s, βˆ‚s) - βˆ‚_mul_Ξ”s_primal = Any[] - βˆ‚_mul_Ξ”s_conjugate = Any[] - βˆ‚_wirtinger_defs = Any[] - for i in 1:length(βˆ‚s) - if i in wirtinger_indices - Ξ”i = Ξ”s[i] - βˆ‚i = Symbol(string(:βˆ‚, i)) - push!(βˆ‚_wirtinger_defs, :($βˆ‚i = $(βˆ‚s[i]))) - βˆ‚fβˆ‚i_mul_Ξ” = :(wirtinger_primal($βˆ‚i) * wirtinger_primal($Ξ”i)) - βˆ‚fβˆ‚iΜ„_mul_Ξ”Μ„ = :(conj(wirtinger_conjugate($βˆ‚i)) * wirtinger_conjugate($Ξ”i)) - βˆ‚fΜ„βˆ‚i_mul_Ξ” = :(wirtinger_conjugate($βˆ‚i) * wirtinger_primal($Ξ”i)) - βˆ‚fΜ„βˆ‚iΜ„_mul_Ξ”Μ„ = :(conj(wirtinger_primal($βˆ‚i)) * wirtinger_conjugate($Ξ”i)) - push!(βˆ‚_mul_Ξ”s_primal, :($βˆ‚fβˆ‚i_mul_Ξ” + $βˆ‚fβˆ‚iΜ„_mul_Ξ”Μ„)) - push!(βˆ‚_mul_Ξ”s_conjugate, :($βˆ‚fΜ„βˆ‚i_mul_Ξ” + $βˆ‚fΜ„βˆ‚iΜ„_mul_Ξ”Μ„)) - else - βˆ‚_mul_Ξ” = :(@thunk($(βˆ‚s[i])) * $(Ξ”s[i])) - push!(βˆ‚_mul_Ξ”s_primal, βˆ‚_mul_Ξ”) - push!(βˆ‚_mul_Ξ”s_conjugate, βˆ‚_mul_Ξ”) - end - end - primal_sum = :(+($(βˆ‚_mul_Ξ”s_primal...))) - conjugate_sum = :(+($(βˆ‚_mul_Ξ”s_conjugate...))) - return quote # This will be a block, so will have value equal to last statement - $(βˆ‚_wirtinger_defs...) - w = Wirtinger($primal_sum, $conjugate_sum) - refine_differential($π’Ÿ, w) - end +function rrule_propagation_expr(π’Ÿ, Ξ”s, βˆ‚s) + βˆ‚s = map(esc, βˆ‚s) + βˆ‚_mul_Ξ”s = [:(chain($(Ξ”s[i]), @thunk($(βˆ‚s[i])))) for i in 1:length(βˆ‚s)] + return :(refine_differential($π’Ÿ, +($(βˆ‚_mul_Ξ”s...)))) end """ From 41f10715e28622aec29f05b5fbc0aa2a7324de9d Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Fri, 27 Sep 2019 12:12:27 +0200 Subject: [PATCH 11/22] fix tests accordingly --- test/rules.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/test/rules.jl b/test/rules.jl index 2e0e89dc5..9cab840c0 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -46,9 +46,9 @@ end @testset "real input" begin # even though our rule was define in terms of Wirtinger, - # pushforward result will be real as real (even if seed is Compex) + # pushforward result will be real as real (even if seed is Complex) - x = rand(Float64) + x = 5.0 f, myabs2_pushforward = frule(myabs2, x) @test f === x^2 @@ -56,22 +56,22 @@ end df = @inferred myabs2_pushforward(NamedTuple(), Ξ”) @test df === x + x - Ξ” = rand(Complex{Int64}) + Ξ” = 2.0 + 3.0im df = @inferred myabs2_pushforward(NamedTuple(), Ξ”) - @test df === Ξ” * (x + x) + @test df === (Ξ” + conj(Ξ”)) * x end @testset "complex input" begin - z = rand(Complex{Float64}) + z = 5.0 + 7.0im f, myabs2_pushforward = frule(myabs2, z) @test f === abs2(z) df = @inferred myabs2_pushforward(NamedTuple(), One()) @test df === Wirtinger(z', z) - Ξ” = rand(Complex{Int64}) + Ξ” = 2.0 + 3.0im df = @inferred myabs2_pushforward(NamedTuple(), Ξ”) - @test df === Wirtinger(Ξ” * z', Ξ” * z) + @test df === Wirtinger(Ξ” * conj(z), conj(Ξ”) * z) end end @@ -134,11 +134,11 @@ end fx, f_pushforward = res df(Ξ”x, Ξ”p) = f_pushforward(NamedTuple(), Ξ”x, Ξ”p) - df_dx::Thunk = df(One(), Zero()) - df_dp::Thunk = df(Zero(), One()) + df_dx = df(One(), Zero()) + df_dp = df(Zero(), One()) @test fx == f(x, p) # Check we still get the normal value, right - @test df_dx() isa expected_type_df_dx - @test df_dp() isa expected_type_df_dp + @test df_dx isa expected_type_df_dx + @test df_dp isa expected_type_df_dp res = rrule(f, x, p) @@ -147,7 +147,7 @@ end dself, df_dx, df_dp = f_pullback(One()) @test fx == f(x, p) # Check we still get the normal value, right @test dself == NO_FIELDS - @test df_dx() isa expected_type_df_dx - @test df_dp() isa expected_type_df_dp + @test df_dx isa expected_type_df_dx + @test df_dp isa expected_type_df_dp end end From 93a7b14d63d989fa2a336e400954a8765b5f136e Mon Sep 17 00:00:00 2001 From: simeonschaub Date: Thu, 3 Oct 2019 13:38:31 +0200 Subject: [PATCH 12/22] Update src/differentials.jl Co-Authored-By: Nick Robinson --- src/differentials.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/differentials.jl b/src/differentials.jl index 832e42845..4117d4d67 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -55,7 +55,8 @@ extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot b Base.iterate(x::AbstractWirtinger) = (x, nothing) Base.iterate(::AbstractWirtinger, ::Any) = nothing -# `conj` is not defined for `AbstractWirtinger` +# `conj` is not defined for `AbstractWirtinger`. +# Need this method to override the definition of `conj` for `AbstractDifferential`. Base.conj(x::AbstractWirtinger) = throw(MethodError(conj, x)) ##### From c9044072c5955054538fde3d3dc1f0ceaf5dd366 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Sat, 5 Oct 2019 15:53:53 +0200 Subject: [PATCH 13/22] add chain(::Real, ::ComplexGradient) --- src/differential_arithmetic.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index 04e09d62f..29206f75f 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -148,6 +148,14 @@ end ) end +@inline function _chain(outer::Real, inner::ComplexGradient, swap_order) + if swap_order + return ComplexGradient(inner.val * outer) + end + return ComplexGradient(outer * inner.val) +end + +# don't know if we actually need this, shouldn't really occur in actual code @inline function _chain(outer::ComplexGradient, inner::ComplexGradient, swap_order) if swap_order return ComplexGradient(conj(inner.val) * outer.val) From a83da2fa9f7dd67d5859db7f61b786a17186943b Mon Sep 17 00:00:00 2001 From: simeonschaub Date: Fri, 18 Oct 2019 16:13:56 +0200 Subject: [PATCH 14/22] Update src/differentials.jl Co-Authored-By: Nick Robinson --- src/differentials.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/differentials.jl b/src/differentials.jl index 4117d4d67..62dc46dc9 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -214,7 +214,7 @@ unthunk(x::AbstractThunk) = unthunk(x()) wirtinger_primal(::Union{AbstractThunk}) = throw(ArgumentError("`wirtinger_primal` is not defined for `AbstractThunk`. Call `unthunk` first.")) wirtinger_conjugate(::Union{AbstractThunk}) = - throw(ArgumentError("`wirtinger_primal` is not defined for `AbstractThunk`. Call `unthunk` first.")) + throw(ArgumentError("`wirtinger_conjugate` is not defined for `AbstractThunk`. Call `unthunk` first.")) ##### ##### `Thunk` From 4d21e1a111e1f5a217af42928ebcb64e3d4bfefe Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Sat, 19 Oct 2019 15:24:39 +0200 Subject: [PATCH 15/22] special case `AbstractWirtinger` in `at_thunk` --- src/differentials.jl | 6 ++++++ test/differentials.jl | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/differentials.jl b/src/differentials.jl index 62dc46dc9..f682f5c05 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -265,6 +265,12 @@ struct Thunk{F} <: AbstractThunk end macro thunk(body) + if body isa Expr && body.head == :call + fname = body.args[1] + if fname in (:Wirtinger, :ComplexGradient) + return :($fname($((:(@thunk $i) for i in body.args[2:end])...))) + end + end return :(Thunk(() -> $(esc(body)))) end diff --git a/test/differentials.jl b/test/differentials.jl index 570b09d88..fdd4c204d 100644 --- a/test/differentials.jl +++ b/test/differentials.jl @@ -89,7 +89,7 @@ @test refine_differential(typeof([1.2]), Wirtinger(2,2)) == 4 # For most differentials, in most domains, this does nothing - for der in (DNE(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0) + for der in (DNE(), @thunk(23), [1 2], One(), Zero(), 0.0) for π’Ÿ in typeof.((1.0 + 1im, [1.0 + 1im], 1.2, [1.2])) @test refine_differential(π’Ÿ, der) === der end From 71e0c7b23c9107ba8525c3cc87e9624de8f00364 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Sat, 19 Oct 2019 16:14:50 +0200 Subject: [PATCH 16/22] add `refine_differential` for `ComplexGradient` --- src/differentials.jl | 6 ++++++ test/differentials.jl | 28 ++++++++++++++++++++++++---- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/differentials.jl b/src/differentials.jl index f682f5c05..edb7db161 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -329,10 +329,16 @@ Converts, if required, a differential object `der` to another differential that is more suited for the domain given by the type π’Ÿ. Often this will behave as the identity function on `der`. """ +function refine_differential end + function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger) w = refine_differential(w) return wirtinger_primal(w) + wirtinger_conjugate(w) end +function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, g::ComplexGradient) + g = refine_differential(g.val) + return real(g) +end refine_differential(::Any, der) = refine_differential(der) # most of the time leave it alone. refine_differential(w::Wirtinger{<:Any,Zero}) = w.primal diff --git a/test/differentials.jl b/test/differentials.jl index fdd4c204d..ad7452a69 100644 --- a/test/differentials.jl +++ b/test/differentials.jl @@ -82,11 +82,31 @@ @testset "Refine Differential" begin - @test refine_differential(typeof(1.0 + 1im), Wirtinger(2,2)) == Wirtinger(2,2) - @test refine_differential(typeof([1.0 + 1im]), Wirtinger(2,2)) == Wirtinger(2,2) + for (p, c) in ( + (2, -3), + (2.0 + im, 5.0 - 3.0im), + ([1+im, 2-im], [-3+im, 4+im]), + (@thunk(1+2), @thunk(4-3)), + ) + w = Wirtinger(p, c) + @testset "$w" begin + @test refine_differential(typeof(1.0 + 1im), w) === w + @test refine_differential(typeof([1.0 + 1im]), w) === w - @test refine_differential(typeof(1.2), Wirtinger(2,2)) == 4 - @test refine_differential(typeof([1.2]), Wirtinger(2,2)) == 4 + @test refine_differential(typeof(1.2), w) == p + c + @test refine_differential(typeof([1.2]), w) == p + c + end + + g = ComplexGradient(c) + @testset "$g" begin + @test refine_differential(typeof(1.0 + 1im), g) === g + @test refine_differential(typeof([1.0 + 1im]), g) === g + + c isa Thunk && continue + @test refine_differential(typeof(1.2), g) == real(c) + @test refine_differential(typeof([1.2]), g) == real(c) + end + end # For most differentials, in most domains, this does nothing for der in (DNE(), @thunk(23), [1 2], One(), Zero(), 0.0) From 491f781a1f81fef43bdbadb22dc4c81799f4496d Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Sat, 19 Oct 2019 16:21:10 +0200 Subject: [PATCH 17/22] overload `Base.real` for some differentials --- src/differentials.jl | 4 ++++ test/differentials.jl | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/differentials.jl b/src/differentials.jl index edb7db161..ef9813c6a 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -148,6 +148,7 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero()) Base.iterate(x::Zero) = (x, nothing) Base.iterate(::Zero, ::Any) = nothing +Base.real(::Zero) = Zero() ##### ##### `DNE` @@ -189,6 +190,7 @@ Base.Broadcast.broadcastable(::One) = Ref(One()) Base.iterate(x::One) = (x, nothing) Base.iterate(::One, ::Any) = nothing +Base.real(::One) = One() ##### ##### `AbstractThunk @@ -216,6 +218,8 @@ wirtinger_primal(::Union{AbstractThunk}) = wirtinger_conjugate(::Union{AbstractThunk}) = throw(ArgumentError("`wirtinger_conjugate` is not defined for `AbstractThunk`. Call `unthunk` first.")) +Base.real(x::AbstractThunk) = real(x()) + ##### ##### `Thunk` ##### diff --git a/test/differentials.jl b/test/differentials.jl index ad7452a69..89638a4c4 100644 --- a/test/differentials.jl +++ b/test/differentials.jl @@ -102,7 +102,6 @@ @test refine_differential(typeof(1.0 + 1im), g) === g @test refine_differential(typeof([1.0 + 1im]), g) === g - c isa Thunk && continue @test refine_differential(typeof(1.2), g) == real(c) @test refine_differential(typeof([1.2]), g) == real(c) end From 06ccb14bde89c7449b2efaf27c394b304cf3d434 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Sat, 19 Oct 2019 16:54:02 +0200 Subject: [PATCH 18/22] add some docstrings --- src/differentials.jl | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/src/differentials.jl b/src/differentials.jl index ef9813c6a..b18494416 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -45,6 +45,20 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself. ##### `AbstractWirtinger` ##### +""" + AbstractWirtinger <: AbstractDifferential + +Represents the differential of a non-holomorphic function taking complex input. + +All subtypes implement [`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref). + +All subtypes wrap real/holomorphic differentials, and should always be the outermost wrapper. +E.g., a typical differential would look like this: +``` +Wirtinger(@thunk(::AbstractArray{Number}), @thunk(::AbstractArray{<:Number})) +``` +`@thunk` and `AbstractArray` are, of course, optional. +""" abstract type AbstractWirtinger <: AbstractDifferential end wirtinger_primal(x) = x @@ -64,7 +78,7 @@ Base.conj(x::AbstractWirtinger) = throw(MethodError(conj, x)) ##### """ - Wirtinger(primal, conjugate) + Wirtinger(primal, conjugate) <: [`AbstractWirtinger`](@ref) Returns a `Wirtinger` instance representing the complex differential: @@ -95,12 +109,23 @@ Base.iterate(::Wirtinger, ::Any) = nothing ##### `ComplexGradient` ##### +""" + ComplexGradient(val) <: [`AbstractWirtinger`](@ref) + +Returns a `ComplexGradient` instance representing the complex differential: + +``` +df = βˆ‚f/βˆ‚Re(z) * dRe(z) + im * βˆ‚f/βˆ‚Im(z) * dIm(z) +``` + +where `f` is a `β„‚(^n) -> ℝ(^m)` function and `val` corresponds to `df`. +""" struct ComplexGradient{T} <: AbstractWirtinger val::T end wirtinger_primal(x::ComplexGradient) = conj(wirtinger_conjugate(x)) -wirtinger_conjugate(x::ComplexGradient) = x.val / 2 +wirtinger_conjugate(x::ComplexGradient) = (1//2) * x.val Base.Broadcast.broadcastable(x::ComplexGradient) = ComplexGradient(broadcastable(x.val)) @@ -268,6 +293,12 @@ struct Thunk{F} <: AbstractThunk f::F end +""" + @thunk body + +Returns `Thunk(() -> body)`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref). +In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`. +""" macro thunk(body) if body isa Expr && body.head == :call fname = body.args[1] From 6ce400cb92d28170568375899a554aead7a0b0fe Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Sat, 19 Oct 2019 21:40:29 +0200 Subject: [PATCH 19/22] move `at_thunk`-magic into separate macro --- src/differentials.jl | 9 +-------- src/rule_definition_tools.jl | 20 ++++++++++++++++++-- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/differentials.jl b/src/differentials.jl index b18494416..4e4d8432f 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -296,16 +296,9 @@ end """ @thunk body -Returns `Thunk(() -> body)`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref). -In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`. +Returns `Thunk(() -> body)` """ macro thunk(body) - if body isa Expr && body.head == :call - fname = body.args[1] - if fname in (:Wirtinger, :ComplexGradient) - return :($fname($((:(@thunk $i) for i in body.args[2:end])...))) - end - end return :(Thunk(() -> $(esc(body)))) end diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 50cb7cf75..d69e9440d 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -224,16 +224,32 @@ end """ function frule_propagation_expr(π’Ÿ, Ξ”s, βˆ‚s) βˆ‚s = map(esc, βˆ‚s) - βˆ‚_mul_Ξ”s = [:(chain(@thunk($(βˆ‚s[i])), $(Ξ”s[i]))) for i in 1:length(βˆ‚s)] + βˆ‚_mul_Ξ”s = [:(chain(@_thunk($(βˆ‚s[i])), $(Ξ”s[i]))) for i in 1:length(βˆ‚s)] return :(refine_differential($π’Ÿ, +($(βˆ‚_mul_Ξ”s...)))) end function rrule_propagation_expr(π’Ÿ, Ξ”s, βˆ‚s) βˆ‚s = map(esc, βˆ‚s) - βˆ‚_mul_Ξ”s = [:(chain($(Ξ”s[i]), @thunk($(βˆ‚s[i])))) for i in 1:length(βˆ‚s)] + βˆ‚_mul_Ξ”s = [:(chain($(Ξ”s[i]), @_thunk($(βˆ‚s[i])))) for i in 1:length(βˆ‚s)] return :(refine_differential($π’Ÿ, +($(βˆ‚_mul_Ξ”s...)))) end +""" + @_thunk body + +Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref). +In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`. +""" +macro _thunk(body) + if body isa Expr && body.head == :call + fname = body.args[1] + if fname in (:Wirtinger, :ComplexGradient) + return :($fname($((:(@thunk $(esc(i))) for i in body.args[2:end])...))) + end + end + return :(@thunk $(esc(body))) +end + """ propagator_name(f, propname) From b960a0912c95d8fbf85bbcbec4f8b932640ddd31 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Sat, 19 Oct 2019 22:26:57 +0200 Subject: [PATCH 20/22] make at_scalar_rule detect wrong Wirtinger rules Tests will break for now, and that's good. --- src/rule_definition_tools.jl | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index d69e9440d..e3babf155 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -241,13 +241,33 @@ Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) o In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`. """ macro _thunk(body) - if body isa Expr && body.head == :call - fname = body.args[1] - if fname in (:Wirtinger, :ComplexGradient) - return :($fname($((:(@thunk $(esc(i))) for i in body.args[2:end])...))) + return _thunk(body) +end + +function _thunk(body) + if body isa Expr + if body.head == :call + fname = body.args[1] + if fname in (:Wirtinger, :ComplexGradient) + return :($fname($(thunk_assert_no_wirtinger.(body.args[2:end])...))) + end + elseif body.head == :escape + return Expr(:escape, _thunk(body.args[1])) end end - return :(@thunk $(esc(body))) + return thunk_assert_no_wirtinger(body) +end + +thunk_assert_no_wirtinger(body) = quote + Thunk( + function() + res = $(esc(body)) + res isa AbstractWirtinger && error(""" + Couldn't automatically handle `AbstractWirtinger` in `@scalar_rule. + Make sure `Wirtinger`/`ComplexGradient` is the outermost function call or write the rule manually.""") + return res + end + ) end """ From 3bc7e229b18ef670a7f41b918c11407a78e5cf44 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Sat, 19 Oct 2019 22:59:06 +0200 Subject: [PATCH 21/22] use `_thunk` as a function, not macro --- src/rule_definition_tools.jl | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index e3babf155..0163a7a8a 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -224,26 +224,22 @@ end """ function frule_propagation_expr(π’Ÿ, Ξ”s, βˆ‚s) βˆ‚s = map(esc, βˆ‚s) - βˆ‚_mul_Ξ”s = [:(chain(@_thunk($(βˆ‚s[i])), $(Ξ”s[i]))) for i in 1:length(βˆ‚s)] + βˆ‚_mul_Ξ”s = [:(chain($(_thunk(βˆ‚s[i])), $(Ξ”s[i]))) for i in 1:length(βˆ‚s)] return :(refine_differential($π’Ÿ, +($(βˆ‚_mul_Ξ”s...)))) end function rrule_propagation_expr(π’Ÿ, Ξ”s, βˆ‚s) βˆ‚s = map(esc, βˆ‚s) - βˆ‚_mul_Ξ”s = [:(chain($(Ξ”s[i]), @_thunk($(βˆ‚s[i])))) for i in 1:length(βˆ‚s)] + βˆ‚_mul_Ξ”s = [:(chain($(Ξ”s[i]), $(_thunk(βˆ‚s[i])))) for i in 1:length(βˆ‚s)] return :(refine_differential($π’Ÿ, +($(βˆ‚_mul_Ξ”s...)))) end """ - @_thunk body + _thunk(body) Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref). In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`. """ -macro _thunk(body) - return _thunk(body) -end - function _thunk(body) if body isa Expr if body.head == :call @@ -261,8 +257,8 @@ end thunk_assert_no_wirtinger(body) = quote Thunk( function() - res = $(esc(body)) - res isa AbstractWirtinger && error(""" + res = $body + res isa ChainRulesCore.AbstractWirtinger && error(""" Couldn't automatically handle `AbstractWirtinger` in `@scalar_rule. Make sure `Wirtinger`/`ComplexGradient` is the outermost function call or write the rule manually.""") return res From e3bf56d9c56f82e84d63a283158834790c44ce74 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Sat, 19 Oct 2019 23:22:55 +0200 Subject: [PATCH 22/22] implement some of @oxinabox's suggestions --- src/differentials.jl | 10 ++-------- src/rule_definition_tools.jl | 14 ++++++-------- test/differentials.jl | 3 --- 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/src/differentials.jl b/src/differentials.jl index 4e4d8432f..1b69ee191 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -66,9 +66,6 @@ wirtinger_conjugate(::Any) = Zero() extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type.")) -Base.iterate(x::AbstractWirtinger) = (x, nothing) -Base.iterate(::AbstractWirtinger, ::Any) = nothing - # `conj` is not defined for `AbstractWirtinger`. # Need this method to override the definition of `conj` for `AbstractDifferential`. Base.conj(x::AbstractWirtinger) = throw(MethodError(conj, x)) @@ -102,9 +99,6 @@ wirtinger_conjugate(x::Wirtinger) = x.conjugate Base.Broadcast.broadcastable(w::Wirtinger) = Wirtinger(broadcastable(w.primal), broadcastable(w.conjugate)) -Base.iterate(x::Wirtinger) = (x, nothing) -Base.iterate(::Wirtinger, ::Any) = nothing - ##### ##### `ComplexGradient` ##### @@ -238,9 +232,9 @@ end unthunk(x) = x unthunk(x::AbstractThunk) = unthunk(x()) -wirtinger_primal(::Union{AbstractThunk}) = +wirtinger_primal(::AbstractThunk) = throw(ArgumentError("`wirtinger_primal` is not defined for `AbstractThunk`. Call `unthunk` first.")) -wirtinger_conjugate(::Union{AbstractThunk}) = +wirtinger_conjugate(::AbstractThunk) = throw(ArgumentError("`wirtinger_conjugate` is not defined for `AbstractThunk`. Call `unthunk` first.")) Base.real(x::AbstractThunk) = real(x()) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 0163a7a8a..a9f55eee9 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -241,15 +241,13 @@ Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) o In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`. """ function _thunk(body) - if body isa Expr - if body.head == :call - fname = body.args[1] - if fname in (:Wirtinger, :ComplexGradient) - return :($fname($(thunk_assert_no_wirtinger.(body.args[2:end])...))) - end - elseif body.head == :escape - return Expr(:escape, _thunk(body.args[1])) + if Meta.isexpr(body, :call) + fname = body.args[1] + if fname in (:Wirtinger, :ComplexGradient) + return :($fname($(thunk_assert_no_wirtinger.(body.args[2:end])...))) end + elseif Meta.isexpr(body, :escape) + return Expr(:escape, _thunk(body.args[1])) end return thunk_assert_no_wirtinger(body) end diff --git a/test/differentials.jl b/test/differentials.jl index 89638a4c4..389e26cdc 100644 --- a/test/differentials.jl +++ b/test/differentials.jl @@ -12,9 +12,6 @@ # TODO: other + methods stack overflow @test_throws ErrorException w*w @test_throws ArgumentError extern(w) - for x in w - @test x === w - end @test broadcastable(w) == w @test_throws MethodError conj(w) end