@@ -16,18 +16,33 @@ at input point `x` to confirm that there are correct ChainRules provided.
16
16
17
17
All keyword arguments except for `fdm` are passed to `isapprox`.
18
18
"""
19
- function test_scalar (f, x; rtol= 1e-9 , atol= 1e-9 , fdm= _fdm, kwargs... )
19
+ function test_scalar (f, x; rtol= 1e-9 , atol= 1e-9 , fdm= _fdm, test_wirtinger = x isa Complex, kwargs... )
20
20
@testset " $f at $x , $(nameof (rule)) " for rule in (rrule, frule)
21
21
res = rule (f, x)
22
22
@test res != = nothing # Check the rule was defined
23
23
fx, ∂x = res
24
24
@test fx == f (x) # Check we still get the normal value, right
25
25
26
26
# Check that we get the derivative right:
27
- @test isapprox (
28
- ∂x (1 ), fdm (f, x);
29
- rtol= rtol, atol= atol, kwargs...
30
- )
27
+ if ! test_wirtinger
28
+ @test isapprox (
29
+ ∂x (1 ), fdm (f, x);
30
+ rtol= rtol, atol= atol, kwargs...
31
+ )
32
+ else
33
+ # For complex arguments, also check if the wirtinger derivative is correct
34
+ ∂Re, ∂Im = fdm .((ϵ -> f (x + ϵ), ϵ -> f (x + im* ϵ)), 0 )
35
+ ∂ = .5 (∂Re - im* ∂Im)
36
+ ∂̅ = .5 (∂Re + im* ∂Im)
37
+ @test isapprox (
38
+ wirtinger_primal (∂x (1 )), ∂;
39
+ rtol= rtol, atol= atol, kwargs...
40
+ )
41
+ @test isapprox (
42
+ extern (wirtinger_conjugate (∂x (1 ))), ∂̅;
43
+ rtol= rtol, atol= atol, kwargs...
44
+ )
45
+ end
31
46
end
32
47
end
33
48
0 commit comments