From 313f65fb8f8a5e14c75d756d2c21eb1f1644c234 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sat, 11 Jan 2020 00:13:30 -0500 Subject: [PATCH 01/12] Remove Wirtinger --- docs/src/index.md | 1 - src/rulesets/Base/base.jl | 9 +------ test/rulesets/Base/base.jl | 26 +----------------- test/runtests.jl | 1 - test/test_util.jl | 55 +++++++++----------------------------- 5 files changed, 15 insertions(+), 77 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index fbda28495..9dd9b0194 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -233,7 +233,6 @@ The most important `AbstractDifferential`s when getting started are the ones abo - `One`, `Zero`: There are special representations of `1` and `0`. They do great things around avoiding expanding `Thunks` in multiplication and (for `Zero`) addition. #### Other `AbstractDifferential`s: don't worry about them right now - - `Wirtinger`: it is complex. The docs need to be better. [Read the links in this issue](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/40). - `Casted`: it implements broadcasting mechanics. See [#10](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/10) - `InplaceableThunk`: it is like a Thunk but it can do `store!` and `accumulate!` in-place. diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 39e2ab1e7..a3e71a461 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -2,7 +2,7 @@ @scalar_rule(zero(x), Zero()) @scalar_rule(sign(x), Zero()) -@scalar_rule(abs2(x), Wirtinger(x', x)) +@scalar_rule(abs2(x), 2x) @scalar_rule(log(x), inv(x)) @scalar_rule(log10(x), inv(x) / log(oftype(x, 10))) @scalar_rule(log2(x), inv(x) / log(oftype(x, 2))) @@ -50,14 +50,10 @@ @scalar_rule(deg2rad(x), π / oftype(x, 180)) @scalar_rule(rad2deg(x), oftype(x, 180) / π) -@scalar_rule(conj(x), Wirtinger(Zero(), One())) -@scalar_rule(adjoint(x), Wirtinger(Zero(), One())) @scalar_rule(transpose(x), One()) @scalar_rule(abs(x::Real), sign(x)) -@scalar_rule(abs(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω)) @scalar_rule(hypot(x::Real), sign(x)) -@scalar_rule(hypot(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω)) @scalar_rule(rem2pi(x, r::RoundingMode), (One(), DoesNotExist())) @scalar_rule(+(x), One()) @@ -98,11 +94,8 @@ (ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u)))) @scalar_rule(fma(x, y, z), (y, x, One())) @scalar_rule(muladd(x, y, z), (y, x, One())) -@scalar_rule(angle(x::Complex), @setup(u = abs2(x)), Wirtinger(-im//2 * x' / u, im//2 * x / u)) @scalar_rule(angle(x::Real), Zero()) -@scalar_rule(real(x::Complex), Wirtinger(1//2, 1//2)) @scalar_rule(real(x::Real), One()) -@scalar_rule(imag(x::Complex), Wirtinger(-im//2, im//2)) @scalar_rule(imag(x::Real), Zero()) # product rule requires special care for arguments where `mul` is non-commutative diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 979cc9925..2bd4eb187 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -62,7 +62,7 @@ end # Trig @testset "math" begin - for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im) + for x in (-0.1, 6.4) test_scalar(deg2rad, x) test_scalar(rad2deg, x) @@ -73,30 +73,6 @@ test_scalar(exp10, x) x isa Real && test_scalar(cbrt, x) - if (x isa Real && x >= 0) || x isa Complex - # this check is needed because these have discontinuities between - # `-10 + im*eps()` and `-10 - im*eps()` - should_test_wirtinger = imag(x) != 0 && real(x) < 0 - test_scalar(sqrt, x; test_wirtinger=should_test_wirtinger) - test_scalar(log, x; test_wirtinger=should_test_wirtinger) - test_scalar(log2, x; test_wirtinger=should_test_wirtinger) - test_scalar(log10, x; test_wirtinger=should_test_wirtinger) - test_scalar(log1p, x; test_wirtinger=should_test_wirtinger) - end - end - end - - @testset "Unary complex functions" begin - for x in (-4.1, 6.4, 1.0+0.5im, -10.0+1.5im) - test_scalar(real, x) - test_scalar(imag, x) - - test_scalar(abs, x) - test_scalar(hypot, x) - - test_scalar(angle, x) - test_scalar(abs2, x) - test_scalar(conj, x) end end diff --git a/test/runtests.jl b/test/runtests.jl index aa471c266..a0e5d74f9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,7 +11,6 @@ using Test # For testing purposes we use a lot of using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule, - Wirtinger, wirtinger_primal, wirtinger_conjugate, Zero, One, DoesNotExist, Thunk, AbstractDifferential Random.seed!(1) # Set seed that all testsets should reset to. diff --git a/test/test_util.jl b/test/test_util.jl index 07d8e8277..62b9a03a8 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -7,7 +7,7 @@ const _fdm = central_fdm(5, 1) """ - test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), test_wirtinger=x isa Complex, kwargs...) + test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) Given a function `f` with scalar input an scalar output, perform finite differencing checks, at input point `x` to confirm that there are correct ChainRules provided. @@ -16,49 +16,22 @@ at input point `x` to confirm that there are correct ChainRules provided. - `f`: Function for which the `frule` and `rrule` should be tested. - `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain). -- `test_wirtinger`: test whether the wirtinger derivative is correct, too - -All keyword arguments except for `fdm` and `test_wirtinger` are passed to `isapprox`. +All keyword arguments except for `fdm` is passed to `isapprox`. """ -function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa Complex, kwargs...) +function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) ensure_not_running_on_functor(f, "test_scalar") - @testset "$f at $x, $(nameof(rule))" for rule in (rrule, frule) - res = rule(f, x) - @test res !== nothing # Check the rule was defined - fx, prop_rule = res + r_res = rrule(f, x) + f_res = frule(f, x, Zero(), 1) + @test r_res !== f_res !== nothing # Check the rule was defined + r_fx, prop_rule = r_res + f_fx, f_∂x = f_res + @testset "$f at $x, $(nameof(rule))" for (rule, fx, ∂x) in ((rrule, r_fx, prop_rule(1)), (frule, f_fx, f_∂x)) @test fx == f(x) # Check we still get the normal value, right if rule == rrule - ∂self, ∂x = prop_rule(1) + ∂self, ∂x = ∂x @test ∂self === NO_FIELDS - else # rule == frule - # Got to input extra first aguement for internals - # But it is only a dummy since this is not a functor - ∂x = prop_rule(NamedTuple(), 1) - end - - - # Check that we get the derivative right: - if !test_wirtinger - @test isapprox( - ∂x, fdm(f, x); - rtol=rtol, atol=atol, kwargs... - ) - else - # For complex arguments, also check if the wirtinger derivative is correct - ∂Re = fdm(ϵ -> f(x + ϵ), 0) - ∂Im = fdm(ϵ -> f(x + im*ϵ), 0) - ∂ = 0.5(∂Re - im*∂Im) - ∂̅ = 0.5(∂Re + im*∂Im) - @test isapprox( - wirtinger_primal(∂x), ∂; - rtol=rtol, atol=atol, kwargs... - ) - @test isapprox( - wirtinger_conjugate(∂x), ∂̅; - rtol=rtol, atol=atol, kwargs... - ) end end end @@ -92,7 +65,9 @@ end function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) ensure_not_running_on_functor(f, "frule_test") xs, ẋs = collect(zip(xẋs...)) - Ω, pushforward = ChainRules.frule(f, xs...) + dself = Zero() + Ω, dΩ_ad = ChainRules.frule(f, xs..., dself, ẋs...) + dΩ_ad = dΩ_ad .+ 0 @test f(xs...) == Ω dΩ_ad = pushforward(NamedTuple(), ẋs...) @@ -204,10 +179,6 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm end end -function Base.isapprox(ad::Wirtinger, fd; kwargs...) - error("Finite differencing with Wirtinger rules not implemented") -end - function Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...) error("Tried to differentiate w.r.t. a `DoesNotExist`") end From 08191e2b456136104627f531948f26374de1e17b Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sat, 11 Jan 2020 00:09:41 -0500 Subject: [PATCH 02/12] Fuse frule --- Project.toml | 4 ++- src/rulesets/Base/base.jl | 14 +++-------- src/rulesets/Base/broadcast.jl | 12 ++++----- src/rulesets/Base/mapreduce.jl | 7 ++---- src/rulesets/LinearAlgebra/blas.jl | 16 ++++-------- src/rulesets/LinearAlgebra/dense.jl | 39 +++++++++-------------------- test/rulesets/Base/base.jl | 36 +++++++++++++++++++------- test/rulesets/Base/broadcast.jl | 4 +-- test/test_util.jl | 1 - 9 files changed, 59 insertions(+), 74 deletions(-) diff --git a/Project.toml b/Project.toml index 8a7a571f2..4c1bc5783 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,9 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "0.4" +Reexport = "0.2" +Requires = "0.5.2" +ChainRulesCore = "0.5" FiniteDifferences = "^0.7" Reexport = "0.2" Requires = "0.5.2, 1" diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index a3e71a461..66332f71d 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -100,11 +100,8 @@ # product rule requires special care for arguments where `mul` is non-commutative -function frule(::typeof(*), x::Number, y::Number) - function times_pushforward(_, Δx, Δy) - return Δx * y + x * Δy - end - return x * y, times_pushforward +function frule(::typeof(*), x::Number, y::Number, _, Δx, Δy) + return x * y, Δx * y + x * Δy end function rrule(::typeof(*), x::Number, y::Number) @@ -114,11 +111,8 @@ function rrule(::typeof(*), x::Number, y::Number) return x * y, times_pullback end -function frule(::typeof(identity), x) - function identity_pushforward(_, ẏ) - return ẏ - end - return x, identity_pushforward +function frule(::typeof(identity), x, _, ẏ) + return x, ẏ end function rrule(::typeof(identity), x) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 2807d63a4..ff4d6a546 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -6,19 +6,17 @@ https://github.com/JuliaLang/julia/issues/22129. =# function _cast_diff(f, x) function element_rule(u) - fu, du = frule(f, u) - fu, extern(du(NamedTuple(), One())) + dself = Zero() + fu, du = frule(f, u, dself, One()) + fu, extern(du) end results = broadcast(element_rule, x) return first.(results), last.(results) end -function frule(::typeof(broadcast), f, x) +function frule(::typeof(broadcast), f, x, _, Δf, Δx) Ω, ∂x = _cast_diff(f, x) - function broadcast_pushforward(_, Δf, Δx) - return Δx .* ∂x - end - return Ω, broadcast_pushforward + return Ω, Δx .* ∂x end function rrule(::typeof(broadcast), f, x) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index b2560274a..ad73f1411 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -50,11 +50,8 @@ end ##### `sum` ##### -function frule(::typeof(sum), x) - function sum_pushforward(_, ẋ) - return sum(ẋ) - end - return sum(x), sum_pushforward +function frule(::typeof(sum), x, _, ẋ) + return sum(x), sum(ẋ) end function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:) diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 261cbe661..1aec8ee3d 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -11,7 +11,7 @@ _zeros(x) = fill!(similar(x), zero(eltype(x))) ##### `BLAS.dot` ##### -frule(::typeof(BLAS.dot), x, y) = frule(dot, x, y) +frule(::typeof(BLAS.dot), x, y, Δself, Δx, Δy) = frule(dot, x, y, Δself, Δx, Δy) rrule(::typeof(BLAS.dot), x, y) = rrule(dot, x, y) @@ -35,12 +35,9 @@ end ##### `BLAS.nrm2` ##### -function frule(::typeof(BLAS.nrm2), x) +function frule(::typeof(BLAS.nrm2), x, _, Δ) Ω = BLAS.nrm2(x) - function nrm2_pushforward(_, Δx) - return sum(Δx * cast(@thunk(x * inv(Ω)))) - end - return Ω, nrm2_pushforward + return Ω, sum(Δx * cast(@thunk(x * inv(Ω)))) end function rrule(::typeof(BLAS.nrm2), x) @@ -70,11 +67,8 @@ end ##### `BLAS.asum` ##### -function frule(::typeof(BLAS.asum), x) - function asum_pushforward(_, Δx) - return sum(cast(sign, x) * Δx) - end - return BLAS.asum(x), asum_pushforward +function frule(::typeof(BLAS.asum), x, _, Δx) + return BLAS.asum(x), sum(cast(sign, x) * Δx) end function rrule(::typeof(BLAS.asum), x) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index f58015627..a1cf60191 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -8,11 +8,8 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}} ##### `dot` ##### -function frule(::typeof(dot), x, y) - function dot_pushforward(Δself, Δx, Δy) - return sum(Δx .* y) + sum(x .* Δy) - end - return dot(x, y), dot_pushforward +function frule(::typeof(dot), x, y, _, Δx, Δy) + return dot(x, y), sum(Δx .* y) + sum(x .* Δy) end function rrule(::typeof(dot), x, y) @@ -26,13 +23,10 @@ end ##### `inv` ##### -function frule(::typeof(inv), x::AbstractArray) +function frule(::typeof(inv), x::AbstractArray, _, Δx) Ω = inv(x) m = @thunk(-Ω) - function inv_pushforward(_, Δx) - return m * Δx * Ω - end - return Ω, inv_pushforward + return Ω, m * Δx * Ω end function rrule(::typeof(inv), x::AbstractArray) @@ -48,14 +42,11 @@ end ##### `det` ##### -function frule(::typeof(det), x) +function frule(::typeof(det), x, _, ẋ) Ω = det(x) - function det_pushforward(_, ẋ) - # TODO Performance optimization: probably there is an efficent - # way to compute this trace without during the full compution within - return Ω * tr(inv(x) * ẋ) - end - return Ω, det_pushforward + # TODO Performance optimization: probably there is an efficent + # way to compute this trace without during the full compution within + return Ω, Ω * tr(inv(x) * ẋ) end function rrule(::typeof(det), x) @@ -70,12 +61,9 @@ end ##### `logdet` ##### -function frule(::typeof(logdet), x) +function frule(::typeof(logdet), x, _, Δx) Ω = logdet(x) - function logdet_pushforward(_, Δx) - return tr(inv(x) * Δx) - end - return Ω, logdet_pushforward + return Ω, tr(inv(x) * Δx) end function rrule(::typeof(logdet), x) @@ -90,11 +78,8 @@ end ##### `trace` ##### -function frule(::typeof(tr), x) - function tr_pushforward(_, Δx) - return tr(Δx) - end - return tr(x), tr_pushforward +function frule(::typeof(tr), x, _, Δx) + return tr(x), tr(Δx) end function rrule(::typeof(tr), x) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 2bd4eb187..ea8c4c0e9 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -51,13 +51,32 @@ test_scalar(acscd, 1/x) test_scalar(acotd, 1/x) end - - @testset "sincos" begin - x, Δx, x̄ = randn(3) - Δz = (randn(), randn()) - - frule_test(sincos, (x, Δx)) - rrule_test(sincos, Δz, (x, x̄)) + @testset "Multivariate" begin + @testset "atan2" begin + # https://en.wikipedia.org/wiki/Atan2 + x, y = rand(2) + ratan = atan(x, y) + u = x^2 + y^2 + datan = y/u - 2x/u + + r, ṙ = frule(atan, x, y, Zero(), 1, 2) + @test r === ratan + @test ṙ === datan + + r, pullback = rrule(atan, x, y) + @test r === ratan + dself, df1, df2 = pullback(1) + @test dself == NO_FIELDS + @test df1 + 2df2 === datan + end + + @testset "sincos" begin + x, Δx, x̄ = randn(3) + Δz = (randn(), randn()) + + frule_test(sincos, (x, Δx)) + rrule_test(sincos, Δz, (x, x̄)) + end end end # Trig @@ -128,8 +147,7 @@ _, x̄ = pb(10.5) @test extern(x̄) == 0 - _, pf = frule(sign, 0.0) - ẏ = pf(NamedTuple(), 10.5) + _, ẏ = frule(sign, 0.0, Zero(), 10.5) @test extern(ẏ) == 0 end end diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 17f2663ab..ac87540d3 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -22,10 +22,8 @@ end @testset "frule" begin x = rand(3, 3) - y, pushforward = frule(broadcast, sin, x) + y, ẏ = frule(broadcast, sin, x, Zero(), Zero(), One()) @test y == sin.(x) - - ẏ = pushforward(NamedTuple(), NamedTuple(), One()) @test extern(ẏ) == cos.(x) end end diff --git a/test/test_util.jl b/test/test_util.jl index 62b9a03a8..a540d8dd4 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -69,7 +69,6 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm Ω, dΩ_ad = ChainRules.frule(f, xs..., dself, ẋs...) dΩ_ad = dΩ_ad .+ 0 @test f(xs...) == Ω - dΩ_ad = pushforward(NamedTuple(), ẋs...) # Correctness testing via finite differencing. dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs)) From d393d12c79451c95d3ed3d0ba3cf32d04a420e3e Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sat, 11 Jan 2020 01:06:13 -0500 Subject: [PATCH 03/12] Fix Project.toml Oops, accidentally removed fdm test --- Project.toml | 2 -- test/test_util.jl | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 4c1bc5783..cf805f400 100644 --- a/Project.toml +++ b/Project.toml @@ -10,8 +10,6 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -Reexport = "0.2" -Requires = "0.5.2" ChainRulesCore = "0.5" FiniteDifferences = "^0.7" Reexport = "0.2" diff --git a/test/test_util.jl b/test/test_util.jl index a540d8dd4..5c6148d94 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -33,6 +33,8 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) ∂self, ∂x = ∂x @test ∂self === NO_FIELDS end + @test isapprox(∂x, fdm(f, x); + rtol=rtol, atol=atol, kwargs...) end end From 16ec82b8ed2b25f52b260856f593358dd8683188 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sat, 11 Jan 2020 12:48:50 -0500 Subject: [PATCH 04/12] Add back `conj`, `adjoint`, and `abs2` for real --- src/rulesets/Base/base.jl | 2 ++ test/rulesets/Base/base.jl | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 66332f71d..2382af7fa 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -50,6 +50,8 @@ @scalar_rule(deg2rad(x), π / oftype(x, 180)) @scalar_rule(rad2deg(x), oftype(x, 180) / π) +@scalar_rule(conj(x), One()) +@scalar_rule(adjoint(x), One()) @scalar_rule(transpose(x), One()) @scalar_rule(abs(x::Real), sign(x)) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index ea8c4c0e9..67e40132c 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -91,6 +91,10 @@ test_scalar(exp2, x) test_scalar(exp10, x) + test_scalar(conj, x) + test_scalar(adjoint, x) + test_scalar(abs2, x) + x isa Real && test_scalar(cbrt, x) end end From 0835a59c7c17f592b9301f5fb3ace5fd149058dd Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sat, 11 Jan 2020 13:57:38 -0500 Subject: [PATCH 05/12] New release --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index cf805f400..f328da04f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.2.5" +version = "0.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From d56978d02ca4b6309c7939632dd0069068ac9258 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sat, 11 Jan 2020 16:34:35 -0500 Subject: [PATCH 06/12] Add type constraint to conj and adjoint --- src/rulesets/Base/base.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 2382af7fa..b36307b9d 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -50,8 +50,8 @@ @scalar_rule(deg2rad(x), π / oftype(x, 180)) @scalar_rule(rad2deg(x), oftype(x, 180) / π) -@scalar_rule(conj(x), One()) -@scalar_rule(adjoint(x), One()) +@scalar_rule(conj(x::Real), One()) +@scalar_rule(adjoint(x::Real), One()) @scalar_rule(transpose(x), One()) @scalar_rule(abs(x::Real), sign(x)) From 34a1bbe9b6253eff25e350371dd198e0344218ef Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sat, 11 Jan 2020 18:18:52 -0500 Subject: [PATCH 07/12] Remove the `.+ 0` hack Ref: https://github.com/JuliaDiff/ChainRulesCore.jl/pull/90 --- Project.toml | 2 +- test/test_util.jl | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index f328da04f..abaa65b2b 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "0.5" +ChainRulesCore = "0.5.1" FiniteDifferences = "^0.7" Reexport = "0.2" Requires = "0.5.2, 1" diff --git a/test/test_util.jl b/test/test_util.jl index 5c6148d94..497cc9655 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -69,7 +69,6 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm xs, ẋs = collect(zip(xẋs...)) dself = Zero() Ω, dΩ_ad = ChainRules.frule(f, xs..., dself, ẋs...) - dΩ_ad = dΩ_ad .+ 0 @test f(xs...) == Ω # Correctness testing via finite differencing. From e0384beba194eb6a0b2e4860b6fcfba54071922d Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 16:35:30 -0500 Subject: [PATCH 08/12] Address code review comments --- src/rulesets/LinearAlgebra/blas.jl | 9 ++++-- src/rulesets/LinearAlgebra/dense.jl | 6 ++-- test/rulesets/Base/base.jl | 43 +++++++++++++++-------------- test/test_util.jl | 8 ++++-- 4 files changed, 35 insertions(+), 31 deletions(-) diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 1aec8ee3d..fb0243dce 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -37,7 +37,7 @@ end function frule(::typeof(BLAS.nrm2), x, _, Δ) Ω = BLAS.nrm2(x) - return Ω, sum(Δx * cast(@thunk(x * inv(Ω)))) + return Ω, sum(Δx .* @thunk(x * inv(Ω))) end function rrule(::typeof(BLAS.nrm2), x) @@ -68,12 +68,15 @@ end ##### function frule(::typeof(BLAS.asum), x, _, Δx) - return BLAS.asum(x), sum(cast(sign, x) * Δx) + return BLAS.asum(x), sum(zip(x, Δx)) do xs + x, Δx = xs + return sign(x) * Δx + end end function rrule(::typeof(BLAS.asum), x) function asum_pullback(ΔΩ) - return (NO_FIELDS, @thunk(ΔΩ * cast(sign, x))) + return (NO_FIELDS, @thunk(ΔΩ * sign.(x))) end return BLAS.asum(x), asum_pullback end diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index a1cf60191..a5191f91a 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -25,15 +25,13 @@ end function frule(::typeof(inv), x::AbstractArray, _, Δx) Ω = inv(x) - m = @thunk(-Ω) - return Ω, m * Δx * Ω + return Ω, -Ω * Δx * Ω end function rrule(::typeof(inv), x::AbstractArray) Ω = inv(x) - m = @thunk(-Ω') function inv_pullback(ΔΩ) - return NO_FIELDS, m * ΔΩ * Ω' + return NO_FIELDS, -Ω' * ΔΩ * Ω' end return Ω, inv_pullback end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 67e40132c..c376500af 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -52,24 +52,6 @@ test_scalar(acotd, 1/x) end @testset "Multivariate" begin - @testset "atan2" begin - # https://en.wikipedia.org/wiki/Atan2 - x, y = rand(2) - ratan = atan(x, y) - u = x^2 + y^2 - datan = y/u - 2x/u - - r, ṙ = frule(atan, x, y, Zero(), 1, 2) - @test r === ratan - @test ṙ === datan - - r, pullback = rrule(atan, x, y) - @test r === ratan - dself, df1, df2 = pullback(1) - @test dself == NO_FIELDS - @test df1 + 2df2 === datan - end - @testset "sincos" begin x, Δx, x̄ = randn(3) Δz = (randn(), randn()) @@ -91,11 +73,30 @@ test_scalar(exp2, x) test_scalar(exp10, x) + test_scalar(cbrt, x) + + if x >= 0 + test_scalar(sqrt, x) + test_scalar(log, x) + test_scalar(log2, x) + test_scalar(log10, x) + test_scalar(log1p, x) + end + end + end + + @testset "Unary complex functions" begin + for x in (-4.1, 6.4) + test_scalar(real, x) + test_scalar(imag, x) + + test_scalar(abs, x) + test_scalar(hypot, x) + + test_scalar(angle, x) + test_scalar(abs2, x) test_scalar(conj, x) test_scalar(adjoint, x) - test_scalar(abs2, x) - - x isa Real && test_scalar(cbrt, x) end end diff --git a/test/test_util.jl b/test/test_util.jl index 497cc9655..ae3fcd611 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -26,15 +26,17 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) @test r_res !== f_res !== nothing # Check the rule was defined r_fx, prop_rule = r_res f_fx, f_∂x = f_res - @testset "$f at $x, $(nameof(rule))" for (rule, fx, ∂x) in ((rrule, r_fx, prop_rule(1)), (frule, f_fx, f_∂x)) + @testset "$f at $x, $(nameof(rule))" for (rule, fx, ∂x) in ( + (rrule, r_fx, prop_rule(1)), + (frule, f_fx, f_∂x) + ) @test fx == f(x) # Check we still get the normal value, right if rule == rrule ∂self, ∂x = ∂x @test ∂self === NO_FIELDS end - @test isapprox(∂x, fdm(f, x); - rtol=rtol, atol=atol, kwargs...) + @test isapprox(∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs...) end end From 3fe59b4f3a693e560c2c07ec7a76f953e7347226 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 16:40:40 -0500 Subject: [PATCH 09/12] Remove BLAS.asum(x) rules --- src/rulesets/LinearAlgebra/blas.jl | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index fb0243dce..f3d187fce 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -67,20 +67,6 @@ end ##### `BLAS.asum` ##### -function frule(::typeof(BLAS.asum), x, _, Δx) - return BLAS.asum(x), sum(zip(x, Δx)) do xs - x, Δx = xs - return sign(x) * Δx - end -end - -function rrule(::typeof(BLAS.asum), x) - function asum_pullback(ΔΩ) - return (NO_FIELDS, @thunk(ΔΩ * sign.(x))) - end - return BLAS.asum(x), asum_pullback -end - function rrule(::typeof(BLAS.asum), n, X, incx) Ω = BLAS.asum(n, X, incx) function asum_pullback(ΔΩ) From 61b425b71b26dd0a2794cebbcd290aa26a02aebb Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 16:45:10 -0500 Subject: [PATCH 10/12] Fix frule for `BLAS.nrm2` --- src/rulesets/LinearAlgebra/blas.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index f3d187fce..6a3229b7e 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -35,7 +35,7 @@ end ##### `BLAS.nrm2` ##### -function frule(::typeof(BLAS.nrm2), x, _, Δ) +function frule(::typeof(BLAS.nrm2), x, _, Δx) Ω = BLAS.nrm2(x) return Ω, sum(Δx .* @thunk(x * inv(Ω))) end From 63b9f7e8efe8a25f1ea2c4fb048c798a9613b04e Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 16:54:47 -0500 Subject: [PATCH 11/12] Revert "Remove BLAS.asum(x) rules" This reverts commit 3fe59b4f3a693e560c2c07ec7a76f953e7347226. --- src/rulesets/LinearAlgebra/blas.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 6a3229b7e..bb0734e52 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -67,6 +67,20 @@ end ##### `BLAS.asum` ##### +function frule(::typeof(BLAS.asum), x, _, Δx) + return BLAS.asum(x), sum(zip(x, Δx)) do xs + x, Δx = xs + return sign(x) * Δx + end +end + +function rrule(::typeof(BLAS.asum), x) + function asum_pullback(ΔΩ) + return (NO_FIELDS, @thunk(ΔΩ * sign.(x))) + end + return BLAS.asum(x), asum_pullback +end + function rrule(::typeof(BLAS.asum), n, X, incx) Ω = BLAS.asum(n, X, incx) function asum_pullback(ΔΩ) From c1a4451032856f5262cf8d17e91f0e75d97698cd Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 16:59:12 -0500 Subject: [PATCH 12/12] =?UTF-8?q?Fix=20asum=20for=20the=20case=20where=20?= =?UTF-8?q?=CE=94x=20is=20a=20matrix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rulesets/LinearAlgebra/blas.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index bb0734e52..77940fbed 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -68,10 +68,7 @@ end ##### function frule(::typeof(BLAS.asum), x, _, Δx) - return BLAS.asum(x), sum(zip(x, Δx)) do xs - x, Δx = xs - return sign(x) * Δx - end + return BLAS.asum(x), sum(sign.(x) .* Δx) end function rrule(::typeof(BLAS.asum), x)