Skip to content

Commit 8b3e69c

Browse files
committed
test wirtinger derivatives with fdm too
1 parent 9230d28 commit 8b3e69c

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

test/rulesets/Base/base.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,13 @@
8787

8888
x isa Real && test_scalar(cbrt, x)
8989
if (x isa Real && x >= 0) || x isa Complex
90-
test_scalar(sqrt, x)
91-
test_scalar(log, x)
92-
test_scalar(log2, x)
93-
test_scalar(log10, x)
94-
test_scalar(log1p, x)
90+
# this check is needed because the imaginary part jumps from pi to -pi
91+
nr = imag(x) != 0 && real(x) < 0
92+
test_scalar(sqrt, x, test_wirtinger=nr)
93+
test_scalar(log, x, test_wirtinger=nr)
94+
test_scalar(log2, x, test_wirtinger=nr)
95+
test_scalar(log10, x, test_wirtinger=nr)
96+
test_scalar(log1p, x, test_wirtinger=nr)
9597
end
9698
end
9799
end

test/test_util.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,33 @@ at input point `x` to confirm that there are correct ChainRules provided.
1616
1717
All keyword arguments except for `fdm` are passed to `isapprox`.
1818
"""
19-
function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
19+
function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa Complex, kwargs...)
2020
@testset "$f at $x, $(nameof(rule))" for rule in (rrule, frule)
2121
res = rule(f, x)
2222
@test res !== nothing # Check the rule was defined
2323
fx, ∂x = res
2424
@test fx == f(x) # Check we still get the normal value, right
2525

2626
# Check that we get the derivative right:
27-
@test isapprox(
28-
∂x(1), fdm(f, x);
29-
rtol=rtol, atol=atol, kwargs...
30-
)
27+
if !test_wirtinger
28+
@test isapprox(
29+
∂x(1), fdm(f, x);
30+
rtol=rtol, atol=atol, kwargs...
31+
)
32+
else
33+
# For complex arguments, also check if the wirtinger derivative is correct
34+
∂Re, ∂Im = fdm.((ϵ -> f(x + ϵ), ϵ -> f(x + im*ϵ)), 0)
35+
= .5(∂Re - im*∂Im)
36+
∂̅ = .5(∂Re + im*∂Im)
37+
@test isapprox(
38+
wirtinger_primal(∂x(1)), ∂;
39+
rtol=rtol, atol=atol, kwargs...
40+
)
41+
@test isapprox(
42+
extern(wirtinger_conjugate(∂x(1))), ∂̅;
43+
rtol=rtol, atol=atol, kwargs...
44+
)
45+
end
3146
end
3247
end
3348

0 commit comments

Comments
 (0)