diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 22f8234b7..cc4005e64 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,16 +2,14 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! export frule, rrule -export refine_differential, wirtinger_conjugate, wirtinger_primal export @scalar_rule, @thunk export extern, store!, unthunk -export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Wirtinger, Zero +export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero export NO_FIELDS include("compat.jl") include("differentials/abstract_differential.jl") -include("differentials/wirtinger.jl") include("differentials/zero.jl") include("differentials/does_not_exist.jl") include("differentials/one.jl") diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index 0eaac8d13..61c950f0b 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -8,44 +8,12 @@ Thus we can avoid any ambiguities. Notice: The precedence goes: - `Wirtinger, Zero, DoesNotExist, One, AbstractThunk, Composite, Any` + `Zero, DoesNotExist, One, AbstractThunk, Composite, Any` Thus each of the @eval loops creating definitions of + and * defines the combination this type with all types of lower precidence. This means each eval loops is 1 item smaller than the previous. ==# - -function Base.:*(a::Wirtinger, b::Wirtinger) - error(""" - Cannot multiply two Wirtinger objects; this error likely means a - `WirtingerRule` was inappropriately defined somewhere. Multiplication - of two Wirtinger objects is not defined because chain rule application - often expands into a non-commutative operation in the Wirtinger - calculus. To put it another way: simply given two Wirtinger objects - and no other information, we can't know "locally" which components to - conjugate in order to implement the chain rule. We could pick a - convention; for example, we could define `a::Wirtinger * b::Wirtinger` - such that we assume the chain rule application is of the form `f_a ∘ f_b` - instead of `f_b ∘ f_a`. However, picking such a convention is likely to - lead to silently incorrect derivatives due to commutativity assumptions - in downstream generic code that deals with the reals. Thus, ChainRulesCore - makes this operation an error instead. - """) -end - -function Base.:+(a::Wirtinger, b::Wirtinger) - return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate) -end - -for T in (:Zero, :DoesNotExist, :One, :AbstractThunk, :Any) - @eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero()) - @eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b - - @eval Base.:*(a::Wirtinger, b::$T) = Wirtinger(a.primal * b, a.conjugate * b) - @eval Base.:*(a::$T, b::Wirtinger) = Wirtinger(a * b.primal, a * b.conjugate) -end - - Base.:+(::Zero, b::Zero) = Zero() Base.:*(::Zero, ::Zero) = Zero() for T in (:DoesNotExist, :One, :AbstractThunk, :Any) diff --git a/src/differentials/abstract_differential.jl b/src/differentials/abstract_differential.jl index fb14ecec3..d011d72d4 100644 --- a/src/differentials/abstract_differential.jl +++ b/src/differentials/abstract_differential.jl @@ -40,13 +40,3 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself. @inline extern(x) = x @inline Base.conj(x::AbstractDifferential) = x - -""" - refine_differential(š’Ÿ::Type, der) - -Converts, if required, a differential object `der` -(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.), -to another differential that is more suited for the domain given by the type š’Ÿ. -Often this will behave as the identity function on `der`. -""" -refine_differential(::Any, der) = der # most of the time leave it alone. diff --git a/src/differentials/wirtinger.jl b/src/differentials/wirtinger.jl deleted file mode 100644 index 1e3fb1317..000000000 --- a/src/differentials/wirtinger.jl +++ /dev/null @@ -1,44 +0,0 @@ -""" - Wirtinger(primal::Union{Number,AbstractDifferential}, - conjugate::Union{Number,AbstractDifferential}) - -Returns a `Wirtinger` instance representing the complex differential: - -``` -df = āˆ‚f/āˆ‚z * dz + āˆ‚f/āˆ‚zĢ„ * dzĢ„ -``` - -where `primal` corresponds to `āˆ‚f/āˆ‚z * dz` and `conjugate` corresponds to `āˆ‚f/āˆ‚zĢ„ * dzĢ„`. - -The two fields of the returned instance can be accessed generically via the -[`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref) methods. -""" -struct Wirtinger{P,C} <: AbstractDifferential - primal::P - conjugate::C - function Wirtinger(primal::Union{Number,AbstractDifferential}, - conjugate::Union{Number,AbstractDifferential}) - return new{typeof(primal),typeof(conjugate)}(primal, conjugate) - end -end - -wirtinger_primal(x::Wirtinger) = x.primal -wirtinger_primal(x) = x - -wirtinger_conjugate(x::Wirtinger) = x.conjugate -wirtinger_conjugate(::Any) = Zero() - -extern(x::Wirtinger) = throw(ArgumentError("`Wirtinger` cannot be converted to an external type.")) - -Base.Broadcast.broadcastable(w::Wirtinger) = Wirtinger(broadcastable(w.primal), - broadcastable(w.conjugate)) - -Base.iterate(x::Wirtinger) = (x, nothing) -Base.iterate(::Wirtinger, ::Any) = nothing - -# TODO: define `conj` for` `Wirtinger` -Base.conj(x::Wirtinger) = throw(MethodError(conj, x)) - -function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger) - return wirtinger_primal(w) + wirtinger_conjugate(w) -end diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index a06820e64..460989d9c 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -74,14 +74,8 @@ macro scalar_rule(call, maybe_setup, partials...) ) f = call.args[1] - # An expression that when evaluated will return the type of the input domain. - # Multiple repetitions of this expression should optimize out. But if it does not then - # may need to move its definition into the body of the `rrule`/`frule` - š’Ÿ = :(typeof(first(promote($(call.args[2:end]...))))) - - frule_expr = scalar_frule_expr(š’Ÿ, f, call, setup_stmts, inputs, partials) - rrule_expr = scalar_rrule_expr(š’Ÿ, f, call, setup_stmts, inputs, partials) - + frule_expr = scalar_frule_expr(f, call, setup_stmts, inputs, partials) + rrule_expr = scalar_rrule_expr(f, call, setup_stmts, inputs, partials) ############################################################################ # Final return: building the expression to insert in the place of this macro @@ -147,7 +141,7 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) return call, setup_stmts, inputs, partials end -function scalar_frule_expr(š’Ÿ, f, call, setup_stmts, inputs, partials) +function scalar_frule_expr(f, call, setup_stmts, inputs, partials) n_outputs = length(partials) n_inputs = length(inputs) @@ -156,7 +150,7 @@ function scalar_frule_expr(š’Ÿ, f, call, setup_stmts, inputs, partials) Ī”s = [Symbol(string(:Ī”, i)) for i in 1:n_inputs] pushforward_returns = map(1:n_outputs) do output_i āˆ‚s = partials[output_i].args - propagation_expr(š’Ÿ, Ī”s, āˆ‚s) + propagation_expr(Ī”s, āˆ‚s) end if n_outputs > 1 # For forward-mode we only return a tuple if output actually a tuple. @@ -182,7 +176,7 @@ function scalar_frule_expr(š’Ÿ, f, call, setup_stmts, inputs, partials) end end -function scalar_rrule_expr(š’Ÿ, f, call, setup_stmts, inputs, partials) +function scalar_rrule_expr(f, call, setup_stmts, inputs, partials) n_outputs = length(partials) n_inputs = length(inputs) @@ -193,7 +187,7 @@ function scalar_rrule_expr(š’Ÿ, f, call, setup_stmts, inputs, partials) # 1 partial derivative per input pullback_returns = map(1:n_inputs) do input_i āˆ‚s = [partial.args[input_i] for partial in partials] - propagation_expr(š’Ÿ, Ī”s, āˆ‚s) + propagation_expr(Ī”s, āˆ‚s) end pullback = quote @@ -212,30 +206,14 @@ function scalar_rrule_expr(š’Ÿ, f, call, setup_stmts, inputs, partials) end """ - propagation_expr(š’Ÿ, Ī”s, āˆ‚s) + propagation_expr(Ī”s, āˆ‚s) Returns the expression for the propagation of the input gradient `Ī”s` though the partials `āˆ‚s`. - - š’Ÿ is an expression that when evaluated returns the type-of the input domain. - For example if the derivative is being taken at the point `1` it returns `Int`. - if it is taken at `1+1im` it returns `Complex{Int}`. - At present it is ignored for non-Wirtinger derivatives. """ -function propagation_expr(š’Ÿ, Ī”s, āˆ‚s) - wirtinger_indices = findall(āˆ‚s) do ex - Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger - end - āˆ‚s = map(esc, āˆ‚s) - if isempty(wirtinger_indices) - return standard_propagation_expr(Ī”s, āˆ‚s) - else - return wirtinger_propagation_expr(š’Ÿ, wirtinger_indices, Ī”s, āˆ‚s) - end -end - -function standard_propagation_expr(Ī”s, āˆ‚s) +function propagation_expr(Ī”s, āˆ‚s) # This is basically Ī”s ā‹… āˆ‚s + āˆ‚s = map(esc, āˆ‚s) # Notice: the thunking of `āˆ‚s[i] (potentially) saves us some computation # if `Ī”s[i]` is a `AbstractDifferential` otherwise it is computed as soon @@ -244,36 +222,6 @@ function standard_propagation_expr(Ī”s, āˆ‚s) return :(+($(āˆ‚_mul_Ī”s...))) end -function wirtinger_propagation_expr(š’Ÿ, wirtinger_indices, Ī”s, āˆ‚s) - āˆ‚_mul_Ī”s_primal = Any[] - āˆ‚_mul_Ī”s_conjugate = Any[] - āˆ‚_wirtinger_defs = Any[] - for i in 1:length(āˆ‚s) - if i in wirtinger_indices - Ī”i = Ī”s[i] - āˆ‚i = Symbol(string(:āˆ‚, i)) - push!(āˆ‚_wirtinger_defs, :($āˆ‚i = $(āˆ‚s[i]))) - āˆ‚fāˆ‚i_mul_Ī” = :(wirtinger_primal($āˆ‚i) * wirtinger_primal($Ī”i)) - āˆ‚fāˆ‚iĢ„_mul_Δ̄ = :(conj(wirtinger_conjugate($āˆ‚i)) * wirtinger_conjugate($Ī”i)) - āˆ‚fĢ„āˆ‚i_mul_Ī” = :(wirtinger_conjugate($āˆ‚i) * wirtinger_primal($Ī”i)) - āˆ‚fĢ„āˆ‚iĢ„_mul_Δ̄ = :(conj(wirtinger_primal($āˆ‚i)) * wirtinger_conjugate($Ī”i)) - push!(āˆ‚_mul_Ī”s_primal, :($āˆ‚fāˆ‚i_mul_Ī” + $āˆ‚fāˆ‚iĢ„_mul_Δ̄)) - push!(āˆ‚_mul_Ī”s_conjugate, :($āˆ‚fĢ„āˆ‚i_mul_Ī” + $āˆ‚fĢ„āˆ‚iĢ„_mul_Δ̄)) - else - āˆ‚_mul_Ī” = :(@thunk($(āˆ‚s[i])) * $(Ī”s[i])) - push!(āˆ‚_mul_Ī”s_primal, āˆ‚_mul_Ī”) - push!(āˆ‚_mul_Ī”s_conjugate, āˆ‚_mul_Ī”) - end - end - primal_sum = :(+($(āˆ‚_mul_Ī”s_primal...))) - conjugate_sum = :(+($(āˆ‚_mul_Ī”s_conjugate...))) - return quote # This will be a block, so will have value equal to last statement - $(āˆ‚_wirtinger_defs...) - w = Wirtinger($primal_sum, $conjugate_sum) - refine_differential($š’Ÿ, w) - end -end - """ propagator_name(f, propname) diff --git a/test/differentials/wirtinger.jl b/test/differentials/wirtinger.jl deleted file mode 100644 index 3bdde8e1d..000000000 --- a/test/differentials/wirtinger.jl +++ /dev/null @@ -1,19 +0,0 @@ -@testset "Wirtinger" begin - w = Wirtinger(1+1im, 2+2im) - @test wirtinger_primal(w) == 1+1im - @test wirtinger_conjugate(w) == 2+2im - @test w + w == Wirtinger(2+2im, 4+4im) - - @test w + One() == w + 1 == w + Thunk(()->1) == Wirtinger(2+1im, 2+2im) - @test w * One() == One() * w == w - @test w * 2 == 2 * w == Wirtinger(2 + 2im, 4 + 4im) - - # TODO: other + methods stack overflow - @test_throws ErrorException w*w - @test_throws ArgumentError extern(w) - for x in w - @test x === w - end - @test broadcastable(w) == w - @test_throws MethodError conj(w) -end diff --git a/test/differentials_common.jl b/test/differentials_common.jl deleted file mode 100644 index 00776d222..000000000 --- a/test/differentials_common.jl +++ /dev/null @@ -1,16 +0,0 @@ -@testset "Differential Common" begin - @testset "Refine Differential" begin - @test refine_differential(typeof(1.0 + 1im), Wirtinger(2,2)) == Wirtinger(2,2) - @test refine_differential(typeof([1.0 + 1im]), Wirtinger(2,2)) == Wirtinger(2,2) - - @test refine_differential(typeof(1.2), Wirtinger(2,2)) == 4 - @test refine_differential(typeof([1.2]), Wirtinger(2,2)) == 4 - - # For most differentials, in most domains, this does nothing - for der in (DoesNotExist(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0) - for š’Ÿ in typeof.((1.0 + 1im, [1.0 + 1im], 1.2, [1.2])) - @test refine_differential(š’Ÿ, der) === der - end - end - end -end diff --git a/test/rules.jl b/test/rules.jl index e23680326..59175efcf 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -38,114 +38,3 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test rrx == 2 @test rr1 == 1 end - - -@testset "Basic Wirtinger scalar_rule" begin - myabs2(x) = abs2(x) - @scalar_rule(myabs2(x), Wirtinger(x', x)) - - @testset "real input" begin - # even though our rule was define in terms of Wirtinger, - # pushforward result will be real as real (even if seed is Compex) - - x = rand(Float64) - f, myabs2_pushforward = frule(myabs2, x) - @test f === x^2 - - Ī” = One() - df = @inferred myabs2_pushforward(NamedTuple(), Ī”) - @test df === x + x - - Ī” = rand(Complex{Int64}) - df = @inferred myabs2_pushforward(NamedTuple(), Ī”) - @test df === Ī” * (x + x) - end - - @testset "complex input" begin - z = rand(Complex{Float64}) - f, myabs2_pushforward = frule(myabs2, z) - @test f === abs2(z) - - df = @inferred myabs2_pushforward(NamedTuple(), One()) - @test df === Wirtinger(z', z) - - Ī” = rand(Complex{Int64}) - df = @inferred myabs2_pushforward(NamedTuple(), Ī”) - @test df === Wirtinger(Ī” * z', Ī” * z) - end -end - - -@testset "Advanced Wirtinger @scalar_rule: abs_to_pow" begin - # This is based on SimeonSchaub excellent example: - # https://gist.github.com/simeonschaub/a6dfcd71336d863b3777093b3b8d9c97 - - # This is much more complex than the previous case - # as it has many different types - # depending on input, and the output types do not always agree - - abs_to_pow(x, p) = abs(x)^p - @scalar_rule( - abs_to_pow(x::Real, p), - ( - p == 0 ? Zero() : p * abs_to_pow(x, p-1) * sign(x), - Ī© * log(abs(x)) - ) - ) - - @scalar_rule( - abs_to_pow(x::Complex, p), - @setup(u = abs(x)), - ( - p == 0 ? Zero() : p * u^(p-1) * Wirtinger(x' / 2u, x / 2u), - Ī© * log(abs(x)) - ) - ) - - - f = abs_to_pow - @testset "f($x, $p)" for (x, p) in Iterators.product( - (2, 3.4, -2.1, -10+0im, 2.3-2im), - (0, 1, 2, 4.3, -2.1, 1+.2im) - ) - expected_type_df_dx = - if iszero(p) - Zero - elseif typeof(x) <: Complex - Wirtinger - elseif typeof(p) <: Complex - Complex - else - Real - end - - expected_type_df_dp = - if typeof(p) <: Real - Real - else - Complex - end - - - res = frule(f, x, p) - @test res !== nothing # Check the rule was defined - fx, f_pushforward = res - df(Ī”x, Ī”p) = f_pushforward(NamedTuple(), Ī”x, Ī”p) - - df_dx::Thunk = df(One(), Zero()) - df_dp::Thunk = df(Zero(), One()) - @test fx == f(x, p) # Check we still get the normal value, right - @test df_dx() isa expected_type_df_dx - @test df_dp() isa expected_type_df_dp - - - res = rrule(f, x, p) - @test res !== nothing # Check the rule was defined - fx, f_pullback = res - dself, df_dx, df_dp = f_pullback(One()) - @test fx == f(x, p) # Check we still get the normal value, right - @test dself == NO_FIELDS - @test df_dx() isa expected_type_df_dx - @test df_dp() isa expected_type_df_dp - end -end diff --git a/test/runtests.jl b/test/runtests.jl index a0b132f9e..b6756a86a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,14 +4,11 @@ using ChainRulesCore using LinearAlgebra: Diagonal using ChainRulesCore: extern, accumulate, accumulate!, store!, Composite, @scalar_rule, - Wirtinger, wirtinger_primal, wirtinger_conjugate, Zero, One, DoesNotExist, Thunk using Base.Broadcast: broadcastable @testset "ChainRulesCore" begin - include("differentials_common.jl") @testset "differentials" begin - include("differentials/wirtinger.jl") include("differentials/zero.jl") include("differentials/one.jl") include("differentials/thunks.jl")