Skip to content

Commit 0f805e8

Browse files
dextoriousChrisRackauckas
authored andcommitted
Complex-mode f:R^n->R gradient fixes and tests.
1 parent fbe6dbc commit 0f805e8

File tree

2 files changed

+48
-19
lines changed

2 files changed

+48
-19
lines changed

src/gradients.jl

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,26 @@ function GradientCache(
2525

2626
if typeof(x) <: AbstractArray # the f:R^n->R case
2727
# need cache arrays for epsilon (c1) and x1 (c2)
28-
epsilon_elemtype = compute_epsilon_elemtype(nothing, x)
29-
if typeof(c1) == Void || eltype(c1) != epsilon_elemtype
30-
_c1 = zeros(epsilon_elemtype, size(x))
28+
if fdtype != Val{:complex} # complex-mode FD only needs one cache, for x+eps*im
29+
epsilon_elemtype = compute_epsilon_elemtype(nothing, x)
30+
if typeof(c1) == Void || eltype(c1) != epsilon_elemtype
31+
_c1 = zeros(epsilon_elemtype, size(x))
32+
else
33+
_c1 = c1
34+
end
35+
epsilon_factor = compute_epsilon_factor(fdtype, real(eltype(x)))
36+
@. _c1 = compute_epsilon(fdtype, real(x), epsilon_factor)
37+
38+
if typeof(c2) != typeof(x) || size(c2) != size(x)
39+
_c2 = copy(x)
40+
else
41+
copy!(_c2, x)
42+
end
3143
else
32-
_c1 = c1
44+
_c1 = x + 0*im
45+
_c2 = nothing
3346
end
34-
epsilon_factor = compute_epsilon_factor(fdtype, real(eltype(x)))
35-
@. _c1 = compute_epsilon(fdtype, real(x), epsilon_factor)
3647

37-
if typeof(c2) != typeof(x) || size(c2) != size(x)
38-
_c2 = copy(x)
39-
else
40-
copy!(_c2, x)
41-
end
4248
else # the f:R->R^n case
4349
# need cache arrays for fx1 and fx2
4450
if typeof(c1) != typeof(df) || size(c1) != size(df)
@@ -117,7 +123,14 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
117123
x[i] += c1[i]
118124
end
119125
elseif fdtype == Val{:complex}
120-
# TODO
126+
epsilon_elemtype = compute_epsilon_elemtype(nothing, x)
127+
epsilon_complex = eps(epsilon_elemtype)
128+
# we use c1 here to avoid typing issues with x
129+
@inbounds for i eachindex(x)
130+
c1[i] += im*epsilon_complex
131+
df[i] = imag(f(c1)) / epsilon_complex
132+
c1[i] -= im*epsilon_complex
133+
end
121134
end
122135
df
123136
end

test/finitedifftests.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
2+
# TODO: add tests for GPUArrays
3+
# TODO: add tests for DEDataArrays
4+
5+
6+
# Derivative tests
17
x = collect(linspace(-2π, 2π, 100))
28
y = sin.(x)
39
df = zeros(100)
@@ -9,9 +15,6 @@ complex_cache = DiffEqDiffTools.DerivativeCache(x, nothing, nothing, Val{:comple
915

1016
err_func(a,b) = maximum(abs.(a-b))
1117

12-
# TODO: add tests for GPUArrays
13-
# TODO: add tests for DEDataArrays
14-
1518
@time @testset "Derivative single point real-valued tests" begin
1619
@test err_func(DiffEqDiffTools.finite_difference_derivative(sin, π/4, Val{:forward}), 2/2) < 1e-4
1720
@test err_func(DiffEqDiffTools.finite_difference_derivative(sin, π/4, Val{:central}), 2/2) < 1e-8
@@ -85,30 +88,43 @@ end
8588
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, central_cache), df_ref) < 1e-8
8689
end
8790

91+
# Gradient tests
8892
f(x) = 2x[1] + x[2]^2
8993
x = rand(2)
9094
fx = f(x)
9195
df = zeros(2)
9296
df_ref = [2., 2*x[2]]
9397
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward},Val{:Real})
9498
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central},Val{:Real})
95-
#complex_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:complex},Val{:Real})
99+
complex_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:complex},Val{:Real})
96100

97101
@time @testset "Gradient of f:R^n->R real-valued tests" begin
98102
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}), df_ref) < 1e-4
99103
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central}), df_ref) < 1e-8
100-
#@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:complex}), df_ref) < 1e-15
104+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:complex}), df_ref) < 1e-15
101105

102106
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:forward}), df_ref) < 1e-4
103107
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:central}), df_ref) < 1e-8
104-
#@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:complex}), df_ref) < 1e-15
108+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:complex}), df_ref) < 1e-15
105109

106110
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, forward_cache), df_ref) < 1e-4
107111
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, central_cache), df_ref) < 1e-8
108-
#@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, complex_cache), df_ref) < 1e-15
112+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, complex_cache), df_ref) < 1e-15
113+
end
114+
115+
@time @testset "Gradient of f:R^n->R complex-valued tests" begin
116+
# TODO
109117
end
110118

119+
@time @testset "Gradient of f:R->R^n real-valued tests" begin
120+
# TODO
121+
end
122+
123+
@time @testset "Gradient of f:R->R^n complex-valued tests" begin
124+
# TODO
125+
end
111126

127+
# Jacobian tests
112128
function f(fvec,x)
113129
fvec[1] = (x[1]+3)*(x[2]^3-7)+18
114130
fvec[2] = sin(x[2]*exp(x[1])-1)

0 commit comments

Comments
 (0)