Skip to content

Commit e8648ff

Browse files
authored
Merge pull request JuliaDiff#90 from simeonschaub/ox/testmore
test wirtinger derivatives with fdm too
2 parents c6e84c1 + 6dda892 commit e8648ff

File tree

3 files changed

+41
-13
lines changed

3 files changed

+41
-13
lines changed

test/rulesets/Base/base.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,14 @@
8787

8888
x isa Real && test_scalar(cbrt, x)
8989
if (x isa Real && x >= 0) || x isa Complex
90-
test_scalar(sqrt, x)
91-
test_scalar(log, x)
92-
test_scalar(log2, x)
93-
test_scalar(log10, x)
94-
test_scalar(log1p, x)
90+
# this check is needed because these have discontinuities between
91+
# `-10 + im*eps()` and `-10 - im*eps()`
92+
should_test_wirtinger = imag(x) != 0 && real(x) < 0
93+
test_scalar(sqrt, x; test_wirtinger=should_test_wirtinger)
94+
test_scalar(log, x; test_wirtinger=should_test_wirtinger)
95+
test_scalar(log2, x; test_wirtinger=should_test_wirtinger)
96+
test_scalar(log10, x; test_wirtinger=should_test_wirtinger)
97+
test_scalar(log1p, x; test_wirtinger=should_test_wirtinger)
9598
end
9699
end
97100
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,23 @@ println("Testing ChainRules.jl")
2828
include(joinpath("rulesets", "Base", "broadcast.jl"))
2929
end
3030

31+
print(" ")
32+
3133
@testset "Statistics" begin
3234
include(joinpath("rulesets", "Statistics", "statistics.jl"))
3335
end
3436

37+
print(" ")
38+
3539
@testset "LinearAlgebra" begin
3640
include(joinpath("rulesets", "LinearAlgebra", "dense.jl"))
3741
include(joinpath("rulesets", "LinearAlgebra", "structured.jl"))
3842
include(joinpath("rulesets", "LinearAlgebra", "factorization.jl"))
3943
include(joinpath("rulesets", "LinearAlgebra", "blas.jl"))
4044
end
4145

46+
print(" ")
47+
4248
@testset "packages" begin
4349
include(joinpath("rulesets", "packages", "NaNMath.jl"))
4450
include(joinpath("rulesets", "packages", "SpecialFunctions.jl"))

test/test_util.jl

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
using FiniteDifferences, Test
22
using FiniteDifferences: jvp, j′vp
33
using ChainRules
4+
using ChainRulesCore: AbstractDifferential
45

56
const _fdm = central_fdm(5, 1)
67

78
"""
8-
test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
9+
test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), test_wirtinger=x isa Complex, kwargs...)
910
1011
Given a function `f` with scalar input an scalar output, perform finite differencing checks,
1112
at input point `x` to confirm that there are correct ChainRules provided.
@@ -14,20 +15,38 @@ at input point `x` to confirm that there are correct ChainRules provided.
1415
- `f`: Function for which the `frule` and `rrule` should be tested.
1516
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
1617
17-
All keyword arguments except for `fdm` are passed to `isapprox`.
18+
- `test_wirtinger`: test whether the wirtinger derivative is correct, too
19+
20+
All keyword arguments except for `fdm` and `test_wirtinger` are passed to `isapprox`.
1821
"""
19-
function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
22+
function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa Complex, kwargs...)
2023
@testset "$f at $x, $(nameof(rule))" for rule in (rrule, frule)
2124
res = rule(f, x)
2225
@test res !== nothing # Check the rule was defined
2326
fx, ∂x = res
2427
@test fx == f(x) # Check we still get the normal value, right
2528

2629
# Check that we get the derivative right:
27-
@test isapprox(
28-
∂x(1), fdm(f, x);
29-
rtol=rtol, atol=atol, kwargs...
30-
)
30+
if !test_wirtinger
31+
@test isapprox(
32+
∂x(1), fdm(f, x);
33+
rtol=rtol, atol=atol, kwargs...
34+
)
35+
else
36+
# For complex arguments, also check if the wirtinger derivative is correct
37+
∂Re = fdm-> f(x + ϵ), 0)
38+
∂Im = fdm-> f(x + im*ϵ), 0)
39+
= 0.5(∂Re - im*∂Im)
40+
∂̅ = 0.5(∂Re + im*∂Im)
41+
@test isapprox(
42+
wirtinger_primal(∂x(1)), ∂;
43+
rtol=rtol, atol=atol, kwargs...
44+
)
45+
@test isapprox(
46+
wirtinger_conjugate(∂x(1)), ∂̅;
47+
rtol=rtol, atol=atol, kwargs...
48+
)
49+
end
3150
end
3251
end
3352

@@ -144,7 +163,7 @@ end
144163
function Base.isapprox(d_ad::DNE, d_fd; kwargs...)
145164
error("Tried to differentiate w.r.t. a DNE")
146165
end
147-
function Base.isapprox(d_ad::Thunk, d_fd; kwargs...)
166+
function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...)
148167
return isapprox(extern(d_ad), d_fd; kwargs...)
149168
end
150169

0 commit comments

Comments
 (0)