Skip to content

Commit 45728b2

Browse files
dextorious authored and ChrisRackauckas committed
Fix gradients of f:R^N->C.
1 parent bdfb7e2 commit 45728b2

File tree

2 files changed

+108
-111
lines changed

2 files changed

+108
-111
lines changed

src/gradients.jl

Lines changed: 88 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function GradientCache(
1111
c1 :: Union{Void,AbstractArray{<:Number}} = nothing,
1212
c2 :: Union{Void,AbstractArray{<:Number}} = nothing,
1313
fdtype :: Type{T1} = Val{:central},
14-
returntype :: Type{T2} = eltype(x),
14+
returntype :: Type{T2} = eltype(df),
1515
inplace :: Type{Val{T3}} = Val{true}) where {T1,T2,T3}
1616

1717
if fdtype!=Val{:forward} && typeof(fx)!=Void
@@ -26,10 +26,15 @@ function GradientCache(
2626
# need cache arrays for x1 (c1) and epsilon (c2) (both only if non-StridedArray)
2727
if fdtype!=Val{:complex} # complex-mode FD only needs one cache, for x+eps*im
2828
if typeof(x)<:StridedVector
29-
_c1 = nothing
30-
_c2 = nothing
31-
if typeof(c1)!=Void || typeof(c2)!=Void
32-
warn("For StridedVectors, neither c1 nor c2 are necessary.")
29+
if eltype(df)<:Complex && !(eltype(x)<:Complex)
30+
_c1 = zeros(Complex{eltype(x)}, size(x))
31+
_c2 = nothing
32+
else
33+
_c1 = nothing
34+
_c2 = nothing
35+
if typeof(c1)!=Void || typeof(c2)!=Void
36+
warn("For StridedVectors, neither c1 nor c2 are necessary.")
37+
end
3338
end
3439
else
3540
if typeof(c1)!=typeof(x) || size(c1)!=size(x)
@@ -108,7 +113,7 @@ function finite_difference_gradient(f, x, fdtype::Type{T1}=Val{:central},
108113
end
109114

110115
function finite_difference_gradient!(df, f, x, fdtype::Type{T1}=Val{:central},
111-
returntype::Type{T2}=eltype(x), inplace::Type{Val{T3}}=Val{true},
116+
returntype::Type{T2}=eltype(df), inplace::Type{Val{T3}}=Val{true},
112117
fx::Union{Void,AbstractArray{<:Number}}=nothing,
113118
c1::Union{Void,AbstractArray{<:Number}}=nothing,
114119
c2::Union{Void,AbstractArray{<:Number}}=nothing,
@@ -130,25 +135,6 @@ function finite_difference_gradient(f,x,
130135
df
131136
end
132137

133-
#=
134-
function f1(df,f,x,epsilon)
135-
for i in eachindex(x)
136-
x0=x[i]
137-
x[i]+=epsilon
138-
dfi=f(x)
139-
x[i]=x0
140-
dfi-=f(x)
141-
df[i]=real(dfi/epsilon)
142-
x[i]+=im*epsilon
143-
dfi=f(x)
144-
x[i]=x0
145-
dfi-=f(x)
146-
df[i]+=im*imag(dfi/(im*epsilon))
147-
end
148-
df
149-
end
150-
=#
151-
152138
# vector of derivatives of a vector->scalar map by each component of a vector x
153139
# this ignores the value of "inplace", because it doesn't make much sense
154140
function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::AbstractArray{<:Number},
@@ -163,19 +149,19 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
163149
copy!(c1,x)
164150
end
165151
if fdtype == Val{:forward}
166-
if eltype(df)<:Complex || returntype<:Complex || eltype(x)<:Complex
167-
for i ∈ eachindex(x)
168-
epsilon = c2[i]
169-
c1_old = c1[i]
170-
c1[i] += epsilon
171-
if typeof(fx) != Void
172-
dfi = (f(c1) - fx) / epsilon
173-
else
174-
fx0 = f(x)
175-
dfi = (f(c1) - fx0) / epsilon
176-
end
177-
df[i] = real(dfi)
178-
c1[i] = c1_old
152+
@inbounds for i ∈ eachindex(x)
153+
epsilon = c2[i]
154+
c1_old = c1[i]
155+
c1[i] += epsilon
156+
if typeof(fx) != Void
157+
dfi = (f(c1) - fx) / epsilon
158+
else
159+
fx0 = f(x)
160+
dfi = (f(c1) - fx0) / epsilon
161+
end
162+
df[i] = real(dfi)
163+
c1[i] = c1_old
164+
if eltype(df)<:Complex
179165
c1[i] += im * epsilon
180166
if typeof(fx) != Void
181167
dfi = (f(c1) - fx) / (im*epsilon)
@@ -185,18 +171,6 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
185171
c1[i] = c1_old
186172
df[i] += im * imag(dfi)
187173
end
188-
else
189-
@inbounds for i ∈ eachindex(x)
190-
epsilon = c2[i]
191-
c1_old = c1[i]
192-
c1[i] += epsilon
193-
if typeof(fx) != Void
194-
df[i] = (f(c1) - fx) / epsilon
195-
else
196-
df[i] = (f(c1) - f(x)) / epsilon
197-
end
198-
c1[i] = c1_old
199-
end
200174
end
201175
elseif fdtype == Val{:central}
202176
@inbounds for i ∈ eachindex(x)
@@ -205,9 +179,16 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
205179
c1[i] += epsilon
206180
x_old = x[i]
207181
x[i] -= epsilon
208-
df[i] = (f(c1) - f(x)) / (2*epsilon)
182+
df[i] = real((f(c1) - f(x)) / (2*epsilon))
209183
c1[i] = c1_old
210184
x[i] = x_old
185+
if eltype(df)<:Complex
186+
c1[i] += im*epsilon
187+
x[i] -= im*epsilon
188+
df[i] += im*imag( (f(c1) - f(x)) / (2*im*epsilon) )
189+
c1[i] = c1_old
190+
x[i] = x_old
191+
end
211192
end
212193
elseif fdtype == Val{:complex} && returntype <: Real
213194
copy!(c1,x)
@@ -228,80 +209,77 @@ end
228209
function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedVector{<:Number},
229210
cache::GradientCache{T1,T2,T3,fdtype,returntype,inplace}) where {T1,T2,T3,fdtype,returntype,inplace}
230211

231-
# c1 is x1, c2 shouldn't exist in this case
212+
# c1 is x1 if we need a complex copy of x, otherwise Void
213+
# c2 is Void
232214
fx, c1, c2 = cache.fx, cache.c1, cache.c2
233215
if fdtype != Val{:complex}
234216
epsilon_factor = compute_epsilon_factor(fdtype, eltype(x))
217+
if eltype(df)<:Complex && !(eltype(x)<:Complex)
218+
copy!(c1,x)
219+
end
235220
end
236221
if fdtype == Val{:forward}
237-
if eltype(df)<:Complex || returntype<:Complex || eltype(x)<:Complex
238-
for i ∈ eachindex(x)
239-
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
240-
x_old = x[i]
241-
if typeof(fx) != Void
242-
x[i] += epsilon
243-
dfi = (f(x) - fx) / epsilon
244-
x[i] = x_old
245-
else
246-
fx0 = f(x)
247-
x[i] += epsilon
248-
dfi = (f(x) - fx0) / epsilon
249-
x[i] = x_old
250-
end
251-
df[i] = real(dfi)
252-
x[i] += im * epsilon
253-
if typeof(fx) != Void
254-
dfi = (f(x) - fx) / (im*epsilon)
255-
else
256-
dfi = (f(x) - fx0) / (im*epsilon)
257-
end
222+
for i ∈ eachindex(x)
223+
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
224+
x_old = x[i]
225+
if typeof(fx) != Void
226+
x[i] += epsilon
227+
dfi = (f(x) - fx) / epsilon
258228
x[i] = x_old
259-
df[i] += im * imag(dfi)
260-
end
261-
else
262-
@inbounds for i ∈ eachindex(x)
263-
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
264-
x_old = x[i]
229+
else
230+
fx0 = f(x)
265231
x[i] += epsilon
266-
dfi = f(x)
232+
dfi = (f(x) - fx0) / epsilon
267233
x[i] = x_old
268-
if typeof(fx) != Void
269-
dfi -= fx
234+
end
235+
df[i] = real(dfi)
236+
if eltype(df)<:Complex
237+
if eltype(x)<:Complex
238+
x[i] += im * epsilon
239+
if typeof(fx) != Void
240+
dfi = (f(x) - fx) / (im*epsilon)
241+
else
242+
dfi = (f(x) - fx0) / (im*epsilon)
243+
end
244+
x[i] = x_old
270245
else
271-
dfi -= f(x)
246+
c1[i] += im * epsilon
247+
if typeof(fx) != Void
248+
dfi = (f(c1) - fx) / (im*epsilon)
249+
else
250+
dfi = (f(c1) - fx0) / (im*epsilon)
251+
end
252+
c1[i] = x_old
272253
end
273-
df[i] = dfi / epsilon
254+
df[i] += im * imag(dfi)
274255
end
275256
end
276257
elseif fdtype == Val{:central}
277-
if eltype(df)<:Complex || returntype<:Complex || eltype(x)<:Complex
278-
@inbounds for i ∈ eachindex(x)
279-
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
280-
x_old = x[i]
281-
x[i] += epsilon
282-
dfi = f(x)
283-
x[i] = x_old - epsilon
284-
dfi -= f(x)
285-
x[i] = x_old
286-
df[i] = real(dfi / (2*epsilon))
287-
x[i] += im*epsilon
288-
dfi = f(x)
289-
x[i] = x_old - im*epsilon
290-
dfi -= f(x)
291-
x[i] = x_old
258+
@inbounds for i ∈ eachindex(x)
259+
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
260+
x_old = x[i]
261+
x[i] += epsilon
262+
dfi = f(x)
263+
x[i] = x_old - epsilon
264+
dfi -= f(x)
265+
x[i] = x_old
266+
df[i] = real(dfi / (2*epsilon))
267+
if eltype(df)<:Complex
268+
if eltype(x)<:Complex
269+
x[i] += im*epsilon
270+
dfi = f(x)
271+
x[i] = x_old - im*epsilon
272+
dfi -= f(x)
273+
x[i] = x_old
274+
else
275+
c1[i] += im*epsilon
276+
dfi = f(c1)
277+
c1[i] = x_old - im*epsilon
278+
dfi -= f(c1)
279+
c1[i] = x_old
280+
end
292281
df[i] += im*imag(dfi / (2*im*epsilon))
293282
end
294-
else
295-
@inbounds for i ∈ eachindex(x)
296-
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
297-
x_old = x[i]
298-
x[i] += epsilon
299-
dfi = f(x)
300-
x[i] = x_old - epsilon
301-
dfi -= f(x)
302-
x[i] = x_old
303-
df[i] = dfi / (2*epsilon)
304-
end
305283
end
306284
elseif fdtype==Val{:complex} && returntype<:Real && eltype(df)<:Real && eltype(x)<:Real
307285
copy!(c1,x)

test/finitedifftests.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,25 @@ central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:centr
145145
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, central_cache), df_ref) < 1e-8
146146
end
147147

148+
f(x) = 2*x[1] + im*x[2]^2
149+
x = ones(2)
150+
fx = f(x)
151+
df = zeros(eltype(fx), size(x))
152+
df_ref = [2.0, im*2*x[2]]
153+
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward},eltype(df))
154+
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central},eltype(df))
155+
156+
@time @testset "Gradient of f : R^N -> C tests" begin
157+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}, eltype(df)), df_ref) < 1e-4
158+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central}, eltype(df)), df_ref) < 1e-8
159+
160+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:forward}, eltype(df)), df_ref) < 1e-4
161+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:central}, eltype(df)), df_ref) < 1e-8
162+
163+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, forward_cache), df_ref) < 1e-4
164+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, central_cache), df_ref) < 1e-8
165+
end
166+
148167
f(df,x) = (df[1]=sin(x); df[2]=cos(x); df)
149168
x = 2π * rand()
150169
fx = zeros(2)
@@ -228,7 +247,7 @@ epsilon = zeros(real.(x))
228247
forward_cache = DiffEqDiffTools.JacobianCache(x,similar(x),similar(x),similar(x),Val{:forward})
229248
central_cache = DiffEqDiffTools.JacobianCache(x,similar(x),similar(x),similar(x))
230249

231-
@time @testset "Jacobian StridedArray complex-valued tests" begin
250+
@time @testset "Jacobian StridedArray f : C^N -> C^N tests" begin
232251
@test err_func(DiffEqDiffTools.finite_difference_jacobian(f, x, forward_cache), J_ref) < 1e-4
233252
@test err_func(DiffEqDiffTools.finite_difference_jacobian(f, x, central_cache), J_ref) < 1e-8
234253
@test err_func(DiffEqDiffTools.finite_difference_jacobian(f, x), J_ref) < 1e-8

0 commit comments

Comments
 (0)