Skip to content

Commit baef649

Browse files
dextoriousChrisRackauckas
authored andcommitted
Real-valued gradients of f:R->R^n and tests.
1 parent 0f805e8 commit baef649

File tree

2 files changed

+55
-17
lines changed

2 files changed

+55
-17
lines changed

src/gradients.jl

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,25 @@ function GradientCache(
4646
end
4747

4848
else # the f:R->R^n case
49-
# need cache arrays for fx1 and fx2
50-
if typeof(c1) != typeof(df) || size(c1) != size(df)
51-
_c1 = similar(df)
52-
else
53-
_c1 = c1
54-
end
55-
if typeof(c2) != typeof(df) || size(c2) != size(df)
56-
_c2 = similar(df)
49+
# need cache arrays for fx1 and fx2, except in complex mode
50+
if fdtype != Val{:complex}
51+
if typeof(c1) != typeof(df) || size(c1) != size(df)
52+
_c1 = similar(df)
53+
else
54+
_c1 = c1
55+
end
56+
if fdtype == Val{:forward} && typeof(fx) != Void
57+
_c2 = nothing
58+
else
59+
if typeof(c2) != typeof(df) || size(c2) != size(df)
60+
_c2 = similar(df)
61+
else
62+
_c2 = c2
63+
end
64+
end
5765
else
58-
_c2 = c2
66+
_c1 = nothing
67+
_c2 = nothing
5968
end
6069
end
6170
GradientCache{typeof(_fx),typeof(_c1),typeof(_c2),fdtype,RealOrComplex}(_fx,_c1,_c2)
@@ -144,16 +153,26 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Number,
144153
# c1 denotes fx1, c2 is fx2, sizes guaranteed by the cache constructor
145154
fx, c1, c2 = cache.fx, cache.c1, cache.c2
146155

156+
epsilon_elemtype = compute_epsilon_elemtype(nothing, x)
147157
if fdtype == Val{:forward}
148-
# TODO
158+
epsilon_factor = compute_epsilon_factor(fdtype, real(eltype(x)))
159+
epsilon = compute_epsilon(Val{:forward}, real(x), epsilon_factor)
160+
c1 .= f(x+epsilon)
161+
if typeof(fx) != Void
162+
@. df = (c1 - fx) / epsilon
163+
else
164+
c2 .= f(x)
165+
@. df = (c1 - c2) / epsilon
166+
end
149167
elseif fdtype == Val{:central}
168+
epsilon_factor = compute_epsilon_factor(fdtype, real(eltype(x)))
169+
epsilon = compute_epsilon(Val{:central}, real(x), epsilon_factor)
150170
c1 .= f(x+epsilon)
151171
c2 .= f(x-epsilon)
152-
@inbounds for i 1 : length(fx)
153-
df[i] = (f(x+epsilon)[1] - f(x-epsilon)[1]) / (2*epsilon)
154-
end
172+
@. df = (c1 - c2) / (2*epsilon)
155173
elseif fdtype == Val{:complex}
156-
# TODO
174+
epsilon_complex = eps(epsilon_elemtype)
175+
df .= imag.(f(x+im*epsilon_complex)) ./ epsilon_complex
157176
end
158177
df
159178
end
@@ -162,13 +181,13 @@ end
162181
function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::AbstractArray{<:Number},
163182
cache::GradientCache{T1,T2,T3,fdtype,Val{:Complex}}) where {T1,T2,T3,fdtype}
164183

165-
# TODO
184+
error("Not implemented yet.")
166185
df
167186
end
168187

169188
function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Number,
170189
cache::GradientCache{T1,T2,T3,fdtype,Val{:Complex}}) where {T1,T2,T3,fdtype}
171190

172-
# TODO
191+
error("Not implemented yet.")
173192
df
174193
end

test/finitedifftests.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,27 @@ end
116116
# TODO
117117
end
118118

119+
f(x) = [sin(x), cos(x)]
120+
x = 2π * rand()
121+
fx = f(x)
122+
df = zeros(2)
123+
df_ref = [cos(x), -sin(x)]
124+
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward},Val{:Real})
125+
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central},Val{:Real})
126+
complex_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:complex},Val{:Real})
127+
119128
@time @testset "Gradient of f:R->R^n real-valued tests" begin
120-
# TODO
129+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}), df_ref) < 1e-4
130+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central}), df_ref) < 1e-8
131+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:complex}), df_ref) < 1e-15
132+
133+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:forward}), df_ref) < 1e-4
134+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:central}), df_ref) < 1e-8
135+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:complex}), df_ref) < 1e-15
136+
137+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, forward_cache), df_ref) < 1e-4
138+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, central_cache), df_ref) < 1e-8
139+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, complex_cache), df_ref) < 1e-15
121140
end
122141

123142
@time @testset "Gradient of f:R->R^n complex-valued tests" begin

0 commit comments

Comments
 (0)