1
1
using FiniteDifferences, Test
2
2
using FiniteDifferences: jvp, j′vp
3
3
using ChainRules
4
+ using ChainRulesCore: AbstractDifferential
4
5
5
6
const _fdm = central_fdm (5 , 1 )
6
7
7
8
"""
8
- test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
9
+ test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), test_wirtinger=x isa Complex, kwargs...)
9
10
10
11
Given a function `f` with scalar input an scalar output, perform finite differencing checks,
11
12
at input point `x` to confirm that there are correct ChainRules provided.
@@ -14,20 +15,38 @@ at input point `x` to confirm that there are correct ChainRules provided.
14
15
- `f`: Function for which the `frule` and `rrule` should be tested.
15
16
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
16
17
17
- All keyword arguments except for `fdm` are passed to `isapprox`.
18
+ - `test_wirtinger`: test whether the wirtinger derivative is correct, too
19
+
20
+ All keyword arguments except for `fdm` and `test_wirtinger` are passed to `isapprox`.
18
21
"""
19
- function test_scalar (f, x; rtol= 1e-9 , atol= 1e-9 , fdm= _fdm, kwargs... )
22
+ function test_scalar (f, x; rtol= 1e-9 , atol= 1e-9 , fdm= _fdm, test_wirtinger = x isa Complex, kwargs... )
20
23
@testset " $f at $x , $(nameof (rule)) " for rule in (rrule, frule)
21
24
res = rule (f, x)
22
25
@test res != = nothing # Check the rule was defined
23
26
fx, ∂x = res
24
27
@test fx == f (x) # Check we still get the normal value, right
25
28
26
29
# Check that we get the derivative right:
27
- @test isapprox (
28
- ∂x (1 ), fdm (f, x);
29
- rtol= rtol, atol= atol, kwargs...
30
- )
30
+ if ! test_wirtinger
31
+ @test isapprox (
32
+ ∂x (1 ), fdm (f, x);
33
+ rtol= rtol, atol= atol, kwargs...
34
+ )
35
+ else
36
+ # For complex arguments, also check if the wirtinger derivative is correct
37
+ ∂Re = fdm (ϵ -> f (x + ϵ), 0 )
38
+ ∂Im = fdm (ϵ -> f (x + im* ϵ), 0 )
39
+ ∂ = 0.5 (∂Re - im* ∂Im)
40
+ ∂̅ = 0.5 (∂Re + im* ∂Im)
41
+ @test isapprox (
42
+ wirtinger_primal (∂x (1 )), ∂;
43
+ rtol= rtol, atol= atol, kwargs...
44
+ )
45
+ @test isapprox (
46
+ wirtinger_conjugate (∂x (1 )), ∂̅;
47
+ rtol= rtol, atol= atol, kwargs...
48
+ )
49
+ end
31
50
end
32
51
end
33
52
144
163
function Base. isapprox (d_ad:: DNE , d_fd; kwargs... )
145
164
error (" Tried to differentiate w.r.t. a DNE" )
146
165
end
147
- function Base. isapprox (d_ad:: Thunk , d_fd; kwargs... )
166
+ function Base. isapprox (d_ad:: AbstractDifferential , d_fd; kwargs... )
148
167
return isapprox (extern (d_ad), d_fd; kwargs... )
149
168
end
150
169
0 commit comments