Skip to content

Commit 0d49880

Browse files
dextoriousChrisRackauckas
authored andcommitted
Gradients of f:C->C^n maps and associated tests.
1 parent baef649 commit 0d49880

File tree

2 files changed

+75
-4
lines changed

2 files changed

+75
-4
lines changed

src/gradients.jl

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
120120
if fdtype == Val{:forward}
121121
@inbounds for i eachindex(x)
122122
c2[i] += c1[i]
123-
df[i] = (f(c2) - f(x)) / c1[i]
123+
if typeof(fx) != Void
124+
df[i] = (f(c2) - fx) / c1[i]
125+
else
126+
df[i] = (f(c2) - f(x)) / c1[i]
127+
end
124128
c2[i] -= c1[i]
125129
end
126130
elseif fdtype == Val{:central}
@@ -177,14 +181,63 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Number,
177181
df
178182
end
179183

180-
184+
# vector of derivatives of f : C^n -> C by each component of a vector x
181185
function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::AbstractArray{<:Number},
182186
cache::GradientCache{T1,T2,T3,fdtype,Val{:Complex}}) where {T1,T2,T3,fdtype}
183187

184-
error("Not implemented yet.")
188+
# NOTE: in this case epsilon is a vector, we need two arrays for epsilon and x1
189+
# c1 denotes epsilon (pre-computed by the cache constructor),
190+
# c2 is x1, pre-set to the values of x by the cache constructor
191+
fx, c1, c2 = cache.fx, cache.c1, cache.c2
192+
if fdtype == Val{:forward}
193+
@inbounds for i eachindex(x)
194+
epsilon = c1[i]
195+
c2[i] += epsilon
196+
if typeof(fx) == Void
197+
df[i] = real(f(c2) - f(x)) / epsilon
198+
else
199+
df[i] = real(f(c2) - fx) / epsilon
200+
end
201+
c2[i] -= epsilon
202+
c2[i] += im*epsilon
203+
if typeof(fx) == Void
204+
df[i] += im*imag(f(c2) - f(x)) / epsilon
205+
else
206+
df[i] += im*imag(f(c2) - fx) / epsilon
207+
end
208+
c2[i] -= im*epsilon
209+
end
210+
elseif fdtype == Val{:central}
211+
@inbounds for i eachindex(x)
212+
epsilon = c1[i]
213+
c2[i] += epsilon
214+
x[i] -= epsilon
215+
df[i] = real(f(c2) - f(x)) / (2*epsilon)
216+
c2[i] -= c1[i]
217+
x[i] += c1[i]
218+
c2[i] += im*epsilon
219+
x[i] -= im*epsilon
220+
df[i] += im*imag(f(c2) - f(x)) / (2*epsilon)
221+
c2[i] -= im*epsilon
222+
x[i] += im*epsilon
223+
end
224+
elseif fdtype == Val{:complex}
225+
epsilon_elemtype = compute_epsilon_elemtype(nothing, x)
226+
epsilon_complex = eps(epsilon_elemtype)
227+
# we use c1 here to avoid typing issues with x
228+
@inbounds for i eachindex(x)
229+
c1[i] += im*epsilon_complex
230+
df[i] = imag(f(c1)) / epsilon_complex
231+
c1[i] -= im*epsilon_complex
232+
end
233+
else
234+
fdtype_error(Val{:Complex})
235+
end
185236
df
186237
end
187238

239+
# vector of derivatives of f : C -> C^n
240+
# this is effectively a vector of partial derivatives, but we still call it a gradient
188241
function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Number,
189242
cache::GradientCache{T1,T2,T3,fdtype,Val{:Complex}}) where {T1,T2,T3,fdtype}
190243

test/finitedifftests.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,26 @@ complex_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:compl
112112
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, complex_cache), df_ref) < 1e-15
113113
end
114114

115+
f(x) = 2x[1] + im*2x[1] + x[2]^2
116+
x = x + im*x
117+
fx = f(x)
118+
df = zeros(x)
119+
df_ref = [2.0+2.0*im, 2.0*x[2]]
120+
DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward})
121+
DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central})
122+
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward},Val{:Real})
123+
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central},Val{:Real})
124+
complex_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:complex},Val{:Real})
125+
115126
@time @testset "Gradient of f:R^n->R complex-valued tests" begin
116-
# TODO
127+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}), df_ref) < 1e-4
128+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central}), df_ref) < 1e-8
129+
130+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:forward}), df_ref) < 1e-4
131+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:central}), df_ref) < 1e-8
132+
133+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, forward_cache), df_ref) < 1e-4
134+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, central_cache), df_ref) < 1e-8
117135
end
118136

119137
f(x) = [sin(x), cos(x)]

0 commit comments

Comments
 (0)