From 58f8881ffeb343a022a0ec22ad3b31bc582f3364 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 10 Jan 2020 19:28:06 +0000 Subject: [PATCH 1/2] WIP removing wirtinger --- Project.toml | 4 ++-- src/rulesets/Base/base.jl | 8 +++----- test/rulesets/Base/base.jl | 16 ++++++++-------- test/runtests.jl | 10 +++++----- test/test_util.jl | 36 +++++++----------------------------- 5 files changed, 25 insertions(+), 49 deletions(-) diff --git a/Project.toml b/Project.toml index 8a7a571f2..206197077 100644 --- a/Project.toml +++ b/Project.toml @@ -10,8 +10,8 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "0.4" -FiniteDifferences = "^0.7" +ChainRulesCore = "0.4, 0.5" +FiniteDifferences = "^0.7, 0.8, 0.9" Reexport = "0.2" Requires = "0.5.2, 1" julia = "^1.0" diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 39e2ab1e7..433ac8568 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -2,7 +2,6 @@ @scalar_rule(zero(x), Zero()) @scalar_rule(sign(x), Zero()) -@scalar_rule(abs2(x), Wirtinger(x', x)) @scalar_rule(log(x), inv(x)) @scalar_rule(log10(x), inv(x) / log(oftype(x, 10))) @scalar_rule(log2(x), inv(x) / log(oftype(x, 2))) @@ -50,14 +49,13 @@ @scalar_rule(deg2rad(x), π / oftype(x, 180)) @scalar_rule(rad2deg(x), oftype(x, 180) / π) -@scalar_rule(conj(x), Wirtinger(Zero(), One())) -@scalar_rule(adjoint(x), Wirtinger(Zero(), One())) +@scalar_rule(conj(x), One()) +@scalar_rule(adjoint(x), One()) @scalar_rule(transpose(x), One()) @scalar_rule(abs(x::Real), sign(x)) -@scalar_rule(abs(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω)) +@scalar_rule(abs2(x), 2sign(x)) @scalar_rule(hypot(x::Real), sign(x)) -@scalar_rule(hypot(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω)) @scalar_rule(rem2pi(x, r::RoundingMode), (One(), DoesNotExist())) @scalar_rule(+(x), One()) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 979cc9925..ee915c086 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -51,7 +51,7 @@ test_scalar(acscd, 1/x) test_scalar(acotd, 1/x) end - + @testset "sincos" begin x, Δx, x̄ = randn(3) Δz = (randn(), randn()) @@ -60,7 +60,7 @@ rrule_test(sincos, Δz, (x, x̄)) end end # Trig - +#== @testset "math" begin for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im) test_scalar(deg2rad, x) @@ -76,12 +76,11 @@ if (x isa Real && x >= 0) || x isa Complex # this check is needed because these have discontinuities between # `-10 + im*eps()` and `-10 - im*eps()` - should_test_wirtinger = imag(x) != 0 && real(x) < 0 - test_scalar(sqrt, x; test_wirtinger=should_test_wirtinger) - test_scalar(log, x; test_wirtinger=should_test_wirtinger) - test_scalar(log2, x; test_wirtinger=should_test_wirtinger) - test_scalar(log10, x; test_wirtinger=should_test_wirtinger) - test_scalar(log1p, x; test_wirtinger=should_test_wirtinger) + test_scalar(sqrt, x) + test_scalar(log, x) + test_scalar(log2, x) + test_scalar(log10, x) + test_scalar(log1p, x) end end end @@ -168,4 +167,5 @@ frule_test(f, (x, Δx), (y, Δy), (z, Δz)) rrule_test(f, Δk, (x, x̄), (y, ȳ), (z, z̄)) end +==# end diff --git a/test/runtests.jl b/test/runtests.jl index aa471c266..ba1c81959 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,7 +11,6 @@ using Test # For testing purposes we use a lot of using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule, - Wirtinger, wirtinger_primal, wirtinger_conjugate, Zero, One, DoesNotExist, Thunk, AbstractDifferential Random.seed!(1) # Set seed that all testsets should reset to. @@ -25,11 +24,11 @@ println("Testing ChainRules.jl") @testset "Base" begin include(joinpath("rulesets", "Base", "base.jl")) - include(joinpath("rulesets", "Base", "array.jl")) - include(joinpath("rulesets", "Base", "mapreduce.jl")) - include(joinpath("rulesets", "Base", "broadcast.jl")) + #include(joinpath("rulesets", "Base", "array.jl")) + #include(joinpath("rulesets", "Base", "mapreduce.jl")) + #include(joinpath("rulesets", "Base", "broadcast.jl")) end - + #== print(" ") @testset "Statistics" begin @@ -51,5 +50,6 @@ println("Testing ChainRules.jl") include(joinpath("rulesets", "packages", "NaNMath.jl")) include(joinpath("rulesets", "packages", "SpecialFunctions.jl")) end + ==# end end diff --git a/test/test_util.jl b/test/test_util.jl index 07d8e8277..b1d2e0b16 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -7,7 +7,7 @@ const _fdm = central_fdm(5, 1) """ - test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), test_wirtinger=x isa Complex, kwargs...) + test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) Given a function `f` with scalar input an scalar output, perform finite differencing checks, at input point `x` to confirm that there are correct ChainRules provided. @@ -16,11 +16,9 @@ at input point `x` to confirm that there are correct ChainRules provided. - `f`: Function for which the `frule` and `rrule` should be tested. - `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain). -- `test_wirtinger`: test whether the wirtinger derivative is correct, too - -All keyword arguments except for `fdm` and `test_wirtinger` are passed to `isapprox`. +All keyword arguments except for `fdm` are passed to `isapprox`. """ -function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa Complex, kwargs...) +function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) ensure_not_running_on_functor(f, "test_scalar") @testset "$f at $x, $(nameof(rule))" for rule in (rrule, frule) @@ -40,26 +38,10 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa # Check that we get the derivative right: - if !test_wirtinger - @test isapprox( - ∂x, fdm(f, x); - rtol=rtol, atol=atol, kwargs... - ) - else - # For complex arguments, also check if the wirtinger derivative is correct - ∂Re = fdm(ϵ -> f(x + ϵ), 0) - ∂Im = fdm(ϵ -> f(x + im*ϵ), 0) - ∂ = 0.5(∂Re - im*∂Im) - ∂̅ = 0.5(∂Re + im*∂Im) - @test isapprox( - wirtinger_primal(∂x), ∂; - rtol=rtol, atol=atol, kwargs... - ) - @test isapprox( - wirtinger_conjugate(∂x), ∂̅; - rtol=rtol, atol=atol, kwargs... - ) - end + @test isapprox( + ∂x, fdm(f, x); + rtol=rtol, atol=atol, kwargs... + ) end end @@ -204,10 +186,6 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm end end -function Base.isapprox(ad::Wirtinger, fd; kwargs...) - error("Finite differencing with Wirtinger rules not implemented") -end - function Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...) error("Tried to differentiate w.r.t. a `DoesNotExist`") end From 2ad6762f1cc6d312d1d73c14d55fde1bfcf176c6 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 10 Jan 2020 19:51:30 +0000 Subject: [PATCH 2/2] keep locked down to FiniteDifferences v0.7 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 206197077..6761d23c0 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "0.4, 0.5" -FiniteDifferences = "^0.7, 0.8, 0.9" +FiniteDifferences = "^0.7" Reexport = "0.2" Requires = "0.5.2, 1" julia = "^1.0"