diff --git a/Project.toml b/Project.toml index 8a7a571f2..abaa65b2b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.2.5" +version = "0.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "0.4" +ChainRulesCore = "0.5.1" FiniteDifferences = "^0.7" Reexport = "0.2" Requires = "0.5.2, 1" diff --git a/docs/src/index.md b/docs/src/index.md index fbda28495..9dd9b0194 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -233,7 +233,6 @@ The most important `AbstractDifferential`s when getting started are the ones abo - `One`, `Zero`: There are special representations of `1` and `0`. They do great things around avoiding expanding `Thunks` in multiplication and (for `Zero`) addition. #### Other `AbstractDifferential`s: don't worry about them right now - - `Wirtinger`: it is complex. The docs need to be better. [Read the links in this issue](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/40). - `Casted`: it implements broadcasting mechanics. See [#10](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/10) - `InplaceableThunk`: it is like a Thunk but it can do `store!` and `accumulate!` in-place. 
diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 39e2ab1e7..b36307b9d 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -2,7 +2,7 @@ @scalar_rule(zero(x), Zero()) @scalar_rule(sign(x), Zero()) -@scalar_rule(abs2(x), Wirtinger(x', x)) +@scalar_rule(abs2(x), 2x) @scalar_rule(log(x), inv(x)) @scalar_rule(log10(x), inv(x) / log(oftype(x, 10))) @scalar_rule(log2(x), inv(x) / log(oftype(x, 2))) @@ -50,14 +50,12 @@ @scalar_rule(deg2rad(x), π / oftype(x, 180)) @scalar_rule(rad2deg(x), oftype(x, 180) / π) -@scalar_rule(conj(x), Wirtinger(Zero(), One())) -@scalar_rule(adjoint(x), Wirtinger(Zero(), One())) +@scalar_rule(conj(x::Real), One()) +@scalar_rule(adjoint(x::Real), One()) @scalar_rule(transpose(x), One()) @scalar_rule(abs(x::Real), sign(x)) -@scalar_rule(abs(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω)) @scalar_rule(hypot(x::Real), sign(x)) -@scalar_rule(hypot(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω)) @scalar_rule(rem2pi(x, r::RoundingMode), (One(), DoesNotExist())) @scalar_rule(+(x), One()) @@ -98,20 +96,14 @@ (ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u)))) @scalar_rule(fma(x, y, z), (y, x, One())) @scalar_rule(muladd(x, y, z), (y, x, One())) -@scalar_rule(angle(x::Complex), @setup(u = abs2(x)), Wirtinger(-im//2 * x' / u, im//2 * x / u)) @scalar_rule(angle(x::Real), Zero()) -@scalar_rule(real(x::Complex), Wirtinger(1//2, 1//2)) @scalar_rule(real(x::Real), One()) -@scalar_rule(imag(x::Complex), Wirtinger(-im//2, im//2)) @scalar_rule(imag(x::Real), Zero()) # product rule requires special care for arguments where `mul` is non-commutative -function frule(::typeof(*), x::Number, y::Number) - function times_pushforward(_, Δx, Δy) - return Δx * y + x * Δy - end - return x * y, times_pushforward +function frule(::typeof(*), x::Number, y::Number, _, Δx, Δy) + return x * y, Δx * y + x * Δy end function rrule(::typeof(*), x::Number, y::Number) @@ -121,11 +113,8 @@ function rrule(::typeof(*), x::Number, y::Number) return x * 
y, times_pullback end -function frule(::typeof(identity), x) - function identity_pushforward(_, ẏ) - return ẏ - end - return x, identity_pushforward +function frule(::typeof(identity), x, _, ẏ) + return x, ẏ end function rrule(::typeof(identity), x) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 2807d63a4..ff4d6a546 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -6,19 +6,17 @@ https://github.com/JuliaLang/julia/issues/22129. =# function _cast_diff(f, x) function element_rule(u) - fu, du = frule(f, u) - fu, extern(du(NamedTuple(), One())) + dself = Zero() + fu, du = frule(f, u, dself, One()) + fu, extern(du) end results = broadcast(element_rule, x) return first.(results), last.(results) end -function frule(::typeof(broadcast), f, x) +function frule(::typeof(broadcast), f, x, _, Δf, Δx) Ω, ∂x = _cast_diff(f, x) - function broadcast_pushforward(_, Δf, Δx) - return Δx .* ∂x - end - return Ω, broadcast_pushforward + return Ω, Δx .* ∂x end function rrule(::typeof(broadcast), f, x) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index b2560274a..ad73f1411 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -50,11 +50,8 @@ end ##### `sum` ##### -function frule(::typeof(sum), x) - function sum_pushforward(_, ẋ) - return sum(ẋ) - end - return sum(x), sum_pushforward +function frule(::typeof(sum), x, _, ẋ) + return sum(x), sum(ẋ) end function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:) diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 261cbe661..77940fbed 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -11,7 +11,7 @@ _zeros(x) = fill!(similar(x), zero(eltype(x))) ##### `BLAS.dot` ##### -frule(::typeof(BLAS.dot), x, y) = frule(dot, x, y) +frule(::typeof(BLAS.dot), x, y, Δself, Δx, Δy) = frule(dot, x, y, Δself, Δx, Δy) rrule(::typeof(BLAS.dot), x, 
y) = rrule(dot, x, y) @@ -35,12 +35,9 @@ end ##### `BLAS.nrm2` ##### -function frule(::typeof(BLAS.nrm2), x) +function frule(::typeof(BLAS.nrm2), x, _, Δx) Ω = BLAS.nrm2(x) - function nrm2_pushforward(_, Δx) - return sum(Δx * cast(@thunk(x * inv(Ω)))) - end - return Ω, nrm2_pushforward + return Ω, sum(Δx .* @thunk(x * inv(Ω))) end function rrule(::typeof(BLAS.nrm2), x) @@ -70,16 +67,13 @@ end ##### `BLAS.asum` ##### -function frule(::typeof(BLAS.asum), x) - function asum_pushforward(_, Δx) - return sum(cast(sign, x) * Δx) - end - return BLAS.asum(x), asum_pushforward +function frule(::typeof(BLAS.asum), x, _, Δx) + return BLAS.asum(x), sum(sign.(x) .* Δx) end function rrule(::typeof(BLAS.asum), x) function asum_pullback(ΔΩ) - return (NO_FIELDS, @thunk(ΔΩ * cast(sign, x))) + return (NO_FIELDS, @thunk(ΔΩ * sign.(x))) end return BLAS.asum(x), asum_pullback end diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index f58015627..a5191f91a 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -8,11 +8,8 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}} ##### `dot` ##### -function frule(::typeof(dot), x, y) - function dot_pushforward(Δself, Δx, Δy) - return sum(Δx .* y) + sum(x .* Δy) - end - return dot(x, y), dot_pushforward +function frule(::typeof(dot), x, y, _, Δx, Δy) + return dot(x, y), sum(Δx .* y) + sum(x .* Δy) end function rrule(::typeof(dot), x, y) @@ -26,20 +23,15 @@ end ##### `inv` ##### -function frule(::typeof(inv), x::AbstractArray) +function frule(::typeof(inv), x::AbstractArray, _, Δx) Ω = inv(x) - m = @thunk(-Ω) - function inv_pushforward(_, Δx) - return m * Δx * Ω - end - return Ω, inv_pushforward + return Ω, -Ω * Δx * Ω end function rrule(::typeof(inv), x::AbstractArray) Ω = inv(x) - m = @thunk(-Ω') function inv_pullback(ΔΩ) - return NO_FIELDS, m * ΔΩ * Ω' + return NO_FIELDS, -Ω' * ΔΩ * Ω' end return Ω, inv_pullback end @@ -48,14 +40,11 @@ end ##### `det` 
##### -function frule(::typeof(det), x) +function frule(::typeof(det), x, _, ẋ) Ω = det(x) - function det_pushforward(_, ẋ) - # TODO Performance optimization: probably there is an efficent - # way to compute this trace without during the full compution within - return Ω * tr(inv(x) * ẋ) - end - return Ω, det_pushforward + # TODO Performance optimization: probably there is an efficent + # way to compute this trace without during the full compution within + return Ω, Ω * tr(inv(x) * ẋ) end function rrule(::typeof(det), x) @@ -70,12 +59,9 @@ end ##### `logdet` ##### -function frule(::typeof(logdet), x) +function frule(::typeof(logdet), x, _, Δx) Ω = logdet(x) - function logdet_pushforward(_, Δx) - return tr(inv(x) * Δx) - end - return Ω, logdet_pushforward + return Ω, tr(inv(x) * Δx) end function rrule(::typeof(logdet), x) @@ -90,11 +76,8 @@ end ##### `trace` ##### -function frule(::typeof(tr), x) - function tr_pushforward(_, Δx) - return tr(Δx) - end - return tr(x), tr_pushforward +function frule(::typeof(tr), x, _, Δx) + return tr(x), tr(Δx) end function rrule(::typeof(tr), x) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 979cc9925..c376500af 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -51,18 +51,19 @@ test_scalar(acscd, 1/x) test_scalar(acotd, 1/x) end - - @testset "sincos" begin - x, Δx, x̄ = randn(3) - Δz = (randn(), randn()) + @testset "Multivariate" begin + @testset "sincos" begin + x, Δx, x̄ = randn(3) + Δz = (randn(), randn()) - frule_test(sincos, (x, Δx)) - rrule_test(sincos, Δz, (x, x̄)) + frule_test(sincos, (x, Δx)) + rrule_test(sincos, Δz, (x, x̄)) + end end end # Trig @testset "math" begin - for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im) + for x in (-0.1, 6.4) test_scalar(deg2rad, x) test_scalar(rad2deg, x) @@ -72,22 +73,20 @@ test_scalar(exp2, x) test_scalar(exp10, x) - x isa Real && test_scalar(cbrt, x) - if (x isa Real && x >= 0) || x isa Complex - # this check is needed because these have 
discontinuities between - # `-10 + im*eps()` and `-10 - im*eps()` - should_test_wirtinger = imag(x) != 0 && real(x) < 0 - test_scalar(sqrt, x; test_wirtinger=should_test_wirtinger) - test_scalar(log, x; test_wirtinger=should_test_wirtinger) - test_scalar(log2, x; test_wirtinger=should_test_wirtinger) - test_scalar(log10, x; test_wirtinger=should_test_wirtinger) - test_scalar(log1p, x; test_wirtinger=should_test_wirtinger) + test_scalar(cbrt, x) + + if x >= 0 + test_scalar(sqrt, x) + test_scalar(log, x) + test_scalar(log2, x) + test_scalar(log10, x) + test_scalar(log1p, x) end end end @testset "Unary complex functions" begin - for x in (-4.1, 6.4, 1.0+0.5im, -10.0+1.5im) + for x in (-4.1, 6.4) test_scalar(real, x) test_scalar(imag, x) @@ -97,6 +96,7 @@ test_scalar(angle, x) test_scalar(abs2, x) test_scalar(conj, x) + test_scalar(adjoint, x) end end @@ -152,8 +152,7 @@ _, x̄ = pb(10.5) @test extern(x̄) == 0 - _, pf = frule(sign, 0.0) - ẏ = pf(NamedTuple(), 10.5) + _, ẏ = frule(sign, 0.0, Zero(), 10.5) @test extern(ẏ) == 0 end end diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 17f2663ab..ac87540d3 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -22,10 +22,8 @@ end @testset "frule" begin x = rand(3, 3) - y, pushforward = frule(broadcast, sin, x) + y, ẏ = frule(broadcast, sin, x, Zero(), Zero(), One()) @test y == sin.(x) - - ẏ = pushforward(NamedTuple(), NamedTuple(), One()) @test extern(ẏ) == cos.(x) end end diff --git a/test/runtests.jl b/test/runtests.jl index aa471c266..a0e5d74f9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,7 +11,6 @@ using Test # For testing purposes we use a lot of using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule, - Wirtinger, wirtinger_primal, wirtinger_conjugate, Zero, One, DoesNotExist, Thunk, AbstractDifferential Random.seed!(1) # Set seed that all testsets should reset to. 
diff --git a/test/test_util.jl b/test/test_util.jl index 07d8e8277..ae3fcd611 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -7,7 +7,7 @@ const _fdm = central_fdm(5, 1) """ - test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), test_wirtinger=x isa Complex, kwargs...) + test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) Given a function `f` with scalar input an scalar output, perform finite differencing checks, at input point `x` to confirm that there are correct ChainRules provided. @@ -16,50 +16,27 @@ at input point `x` to confirm that there are correct ChainRules provided. - `f`: Function for which the `frule` and `rrule` should be tested. - `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain). -- `test_wirtinger`: test whether the wirtinger derivative is correct, too - -All keyword arguments except for `fdm` and `test_wirtinger` are passed to `isapprox`. +All keyword arguments except for `fdm` are passed to `isapprox`. """ -function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa Complex, kwargs...) +function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) 
ensure_not_running_on_functor(f, "test_scalar") - @testset "$f at $x, $(nameof(rule))" for rule in (rrule, frule) - res = rule(f, x) - @test res !== nothing # Check the rule was defined - fx, prop_rule = res + r_res = rrule(f, x) + f_res = frule(f, x, Zero(), 1) + @test r_res !== nothing && f_res !== nothing # Check both rules were defined + r_fx, prop_rule = r_res + f_fx, f_∂x = f_res + @testset "$f at $x, $(nameof(rule))" for (rule, fx, ∂x) in ( + (rrule, r_fx, prop_rule(1)), + (frule, f_fx, f_∂x) + ) @test fx == f(x) # Check we still get the normal value, right if rule == rrule - ∂self, ∂x = prop_rule(1) + ∂self, ∂x = ∂x @test ∂self === NO_FIELDS - else # rule == frule - # Got to input extra first aguement for internals - # But it is only a dummy since this is not a functor - ∂x = prop_rule(NamedTuple(), 1) - end - - - # Check that we get the derivative right: - if !test_wirtinger - @test isapprox( - ∂x, fdm(f, x); - rtol=rtol, atol=atol, kwargs... - ) - else - # For complex arguments, also check if the wirtinger derivative is correct - ∂Re = fdm(ϵ -> f(x + ϵ), 0) - ∂Im = fdm(ϵ -> f(x + im*ϵ), 0) - ∂ = 0.5(∂Re - im*∂Im) - ∂̅ = 0.5(∂Re + im*∂Im) - @test isapprox( - wirtinger_primal(∂x), ∂; - rtol=rtol, atol=atol, kwargs... - ) - @test isapprox( - wirtinger_conjugate(∂x), ∂̅; - rtol=rtol, atol=atol, kwargs... - ) end + @test isapprox(∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs...) end end @@ -92,9 +69,9 @@ end function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) ensure_not_running_on_functor(f, "frule_test") xs, ẋs = collect(zip(xẋs...)) - Ω, pushforward = ChainRules.frule(f, xs...) + dself = Zero() + Ω, dΩ_ad = ChainRules.frule(f, xs..., dself, ẋs...) @test f(xs...) == Ω - dΩ_ad = pushforward(NamedTuple(), ẋs...) # Correctness testing via finite differencing. 
dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs)) @@ -204,10 +181,6 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm end end -function Base.isapprox(ad::Wirtinger, fd; kwargs...) - error("Finite differencing with Wirtinger rules not implemented") -end - function Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...) error("Tried to differentiate w.r.t. a `DoesNotExist`") end