Skip to content

Commit bdfb7e2

Browse files
dextorious authored and ChrisRackauckas committed
Preliminary fix for gradients of f:C^N->R.
1 parent 3d3bb84 commit bdfb7e2

File tree

2 files changed

+140
-31
lines changed

2 files changed

+140
-31
lines changed

src/gradients.jl

Lines changed: 118 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -130,6 +130,25 @@ function finite_difference_gradient(f,x,
130130
df
131131
end
132132

133+
#=
134+
function f1(df,f,x,epsilon)
135+
for i in eachindex(x)
136+
x0=x[i]
137+
x[i]+=epsilon
138+
dfi=f(x)
139+
x[i]=x0
140+
dfi-=f(x)
141+
df[i]=real(dfi/epsilon)
142+
x[i]+=im*epsilon
143+
dfi=f(x)
144+
x[i]=x0
145+
dfi-=f(x)
146+
df[i]+=im*imag(dfi/(im*epsilon))
147+
end
148+
df
149+
end
150+
=#
151+
133152
# vector of derivatives of a vector->scalar map by each component of a vector x
134153
# this ignores the value of "inplace", because it doesn't make much sense
135154
function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::AbstractArray{<:Number},
@@ -144,16 +163,40 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
144163
copy!(c1,x)
145164
end
146165
if fdtype == Val{:forward}
147-
@inbounds for i ∈ eachindex(x)
148-
epsilon = c2[i]
149-
c1_old = c1[i]
150-
c1[i] += epsilon
151-
if typeof(fx) != Void
152-
df[i] = (f(c1) - fx) / epsilon
153-
else
154-
df[i] = (f(c1) - f(x)) / epsilon
166+
if eltype(df)<:Complex || returntype<:Complex || eltype(x)<:Complex
167+
for i ∈ eachindex(x)
168+
epsilon = c2[i]
169+
c1_old = c1[i]
170+
c1[i] += epsilon
171+
if typeof(fx) != Void
172+
dfi = (f(c1) - fx) / epsilon
173+
else
174+
fx0 = f(x)
175+
dfi = (f(c1) - fx0) / epsilon
176+
end
177+
df[i] = real(dfi)
178+
c1[i] = c1_old
179+
c1[i] += im * epsilon
180+
if typeof(fx) != Void
181+
dfi = (f(c1) - fx) / (im*epsilon)
182+
else
183+
dfi = (f(c1) - fx0) / (im*epsilon)
184+
end
185+
c1[i] = c1_old
186+
df[i] += im * imag(dfi)
187+
end
188+
else
189+
@inbounds for i ∈ eachindex(x)
190+
epsilon = c2[i]
191+
c1_old = c1[i]
192+
c1[i] += epsilon
193+
if typeof(fx) != Void
194+
df[i] = (f(c1) - fx) / epsilon
195+
else
196+
df[i] = (f(c1) - f(x)) / epsilon
197+
end
198+
c1[i] = c1_old
155199
end
156-
c1[i] = c1_old
157200
end
158201
elseif fdtype == Val{:central}
159202
@inbounds for i ∈ eachindex(x)
@@ -191,31 +234,76 @@ function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedV
191234
epsilon_factor = compute_epsilon_factor(fdtype, eltype(x))
192235
end
193236
if fdtype == Val{:forward}
194-
@inbounds for i ∈ eachindex(x)
195-
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
196-
x_old = x[i]
197-
x[i] += epsilon
198-
dfi = f(x)
199-
x[i] = x_old
200-
if typeof(fx) != Void
201-
dfi -= fx
202-
else
203-
dfi -= f(x)
237+
if eltype(df)<:Complex || returntype<:Complex || eltype(x)<:Complex
238+
for i ∈ eachindex(x)
239+
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
240+
x_old = x[i]
241+
if typeof(fx) != Void
242+
x[i] += epsilon
243+
dfi = (f(x) - fx) / epsilon
244+
x[i] = x_old
245+
else
246+
fx0 = f(x)
247+
x[i] += epsilon
248+
dfi = (f(x) - fx0) / epsilon
249+
x[i] = x_old
250+
end
251+
df[i] = real(dfi)
252+
x[i] += im * epsilon
253+
if typeof(fx) != Void
254+
dfi = (f(x) - fx) / (im*epsilon)
255+
else
256+
dfi = (f(x) - fx0) / (im*epsilon)
257+
end
258+
x[i] = x_old
259+
df[i] += im * imag(dfi)
260+
end
261+
else
262+
@inbounds for i ∈ eachindex(x)
263+
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
264+
x_old = x[i]
265+
x[i] += epsilon
266+
dfi = f(x)
267+
x[i] = x_old
268+
if typeof(fx) != Void
269+
dfi -= fx
270+
else
271+
dfi -= f(x)
272+
end
273+
df[i] = dfi / epsilon
204274
end
205-
df[i] = dfi / epsilon
206275
end
207276
elseif fdtype == Val{:central}
208-
@inbounds for i ∈ eachindex(x)
209-
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
210-
x_old = x[i]
211-
x[i] += epsilon
212-
dfi = f(x)
213-
x[i] = x_old - epsilon
214-
dfi -= f(x)
215-
x[i] = x_old
216-
df[i] = dfi / (2*epsilon)
277+
if eltype(df)<:Complex || returntype<:Complex || eltype(x)<:Complex
278+
@inbounds for i ∈ eachindex(x)
279+
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
280+
x_old = x[i]
281+
x[i] += epsilon
282+
dfi = f(x)
283+
x[i] = x_old - epsilon
284+
dfi -= f(x)
285+
x[i] = x_old
286+
df[i] = real(dfi / (2*epsilon))
287+
x[i] += im*epsilon
288+
dfi = f(x)
289+
x[i] = x_old - im*epsilon
290+
dfi -= f(x)
291+
x[i] = x_old
292+
df[i] += im*imag(dfi / (2*im*epsilon))
293+
end
294+
else
295+
@inbounds for i ∈ eachindex(x)
296+
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
297+
x_old = x[i]
298+
x[i] += epsilon
299+
dfi = f(x)
300+
x[i] = x_old - epsilon
301+
dfi -= f(x)
302+
x[i] = x_old
303+
df[i] = dfi / (2*epsilon)
304+
end
217305
end
218-
elseif fdtype == Val{:complex} && returntype <: Real
306+
elseif fdtype==Val{:complex} && returntype<:Real && eltype(df)<:Real && eltype(x)<:Real
219307
copy!(c1,x)
220308
epsilon_complex = eps(real(eltype(x)))
221309
# we use c1 here to avoid typing issues with x

test/finitedifftests.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,28 @@ DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central})
113113
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward})
114114
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central})
115115

116-
@time @testset "Gradient of f:vector->scalar complex-valued tests" begin
116+
@time @testset "Gradient of f : C^N -> C tests" begin
117+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}), df_ref) < 1e-4
118+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central}), df_ref) < 1e-8
119+
120+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:forward}), df_ref) < 1e-4
121+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:central}), df_ref) < 1e-8
122+
123+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, forward_cache), df_ref) < 1e-4
124+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, central_cache), df_ref) < 1e-8
125+
end
126+
127+
f(x) = sum(abs2, x)
128+
x = ones(2) * (1 + im)
129+
fx = f(x)
130+
df = zeros(x)
131+
df_ref = 2*conj.(x)
132+
DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward})
133+
DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central})
134+
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward})
135+
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central})
136+
137+
@time @testset "Gradient of f : C^N -> R tests" begin
117138
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}), df_ref) < 1e-4
118139
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central}), df_ref) < 1e-8
119140

0 commit comments

Comments
 (0)