Skip to content

Commit af1dbbe

Browse files
dextoriousChrisRackauckas
authored andcommitted
Fix state before refactoring everything.
1 parent 0d49880 commit af1dbbe

File tree

4 files changed

+83
-51
lines changed

4 files changed

+83
-51
lines changed

src/derivatives.jl

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,26 @@
11
#=
2-
Derivative of f : R -> R or f : C -> C at a single point x.
2+
Sinple-point derivatives of scalar->scalar maps.
33
=#
4-
function finite_difference_derivative(f, x::T, fdtype::DataType, funtype::DataType=Val{:Real},
5-
f_x::Union{Void,T}=nothing) where T<:Number
4+
function finite_difference_derivative(f, x::T, fdtype::DataType=Val{:central},
5+
returntype::DataType=eltype(x), f_x::Union{Void,T}=nothing) where T<:Number
66

7-
if funtype == Val{:Real}
8-
if fdtype == Val{:complex}
9-
epsilon = eps(T)
10-
else
11-
epsilon = compute_epsilon(fdtype, x)
12-
end
13-
elseif funtype == Val{:Complex}
14-
epsilon = compute_epsilon(fdtype, real(x))
15-
else
16-
fdtype_error(funtype)
7+
epsilon = compute_epsilon(fdtype, real(x))
8+
if fdtype == Val{:forward}
9+
return (f(x+epsilon) - f(x)) / epsilon
10+
elseif fdtype == Val{:central}
11+
return (f(x+epsilon) - f(x-epsilon)) / (2*epsilon)
12+
elseif fdtype == Val{:complex} && returntype == Val{:Real}
13+
return imag(f(x+im*epsilon)) / epsilon
1714
end
18-
19-
_finite_difference_kernel(f, x, fdtype, funtype, epsilon, f_x)
15+
fdtype_error(returntype)
2016
end
2117

2218
#=
23-
Finite difference kernels for single point derivatives of f : R -> R.
24-
These are currently underused because of inlining / broadcast issues.
19+
Finite difference kernels for single point derivatives.
20+
These are currently unused because of inlining / broadcast issues.
2521
Revisit this in Julia v0.7 / 1.0.
2622
=#
23+
#=
2724
@inline function _finite_difference_kernel(f, x::T, ::Type{Val{:forward}}, ::Type{Val{:Real}},
2825
epsilon::T, fx::Union{Void,T}=nothing) where T<:Real
2926
@@ -61,6 +58,7 @@ end
6158
6259
real(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon) + im*imag(f(x+im*epsilon) - f(x-im*epsilon)) / (2 * epsilon)
6360
end
61+
=#
6462
# Single point derivative implementations end here.
6563

6664

@@ -179,9 +177,9 @@ function _finite_difference_derivative!(df::AbstractArray{<:Number}, f, x::Abstr
179177
if typeof(fx) == Void
180178
fx = f.(x)
181179
end
182-
@. df = real((f(x+epsilon) - fx)) / epsilon + im*imag((f(x+im*epsilon) - fx)) / epsilon
180+
@. df = real((f(x+epsilon) - fx)) / epsilon + im*imag((f(x+epsilon) - fx)) / epsilon
183181
elseif fdtype == Val{:central}
184-
@. df = real(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon) + im*imag(f(x+im*epsilon) - f(x-im*epsilon)) / (2 * epsilon)
182+
@. df = real(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon) + im*imag(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
185183
else
186184
fdtype_error(Val{:Complex})
187185
end
@@ -242,13 +240,13 @@ function _finite_difference_derivative!(df::StridedArray{<:Number}, f, x::Stride
242240
else
243241
fxi = fx[i]
244242
end
245-
df[i] = ( real( f(x[i]+epsilon) - fxi ) + im*imag( f(x[i]+im*epsilon) - fxi ) ) / epsilon
243+
df[i] = ( real( f(x[i]+epsilon) - fxi ) + im*imag( f(x[i]+epsilon) - fxi ) ) / epsilon
246244
end
247245
elseif fdtype == Val{:central}
248246
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
249247
@inbounds for i in 1 : length(x)
250248
epsilon = compute_epsilon(Val{:central}, real(x[i]), epsilon_factor)
251-
df[i] = (real(f(x[i]+epsilon) - f(x[i]-epsilon)) + im*imag(f(x[i]+im*epsilon) - f(x[i]-im*epsilon))) / (2 * epsilon)
249+
df[i] = ( real( f(x[i]+epsilon) - f(x[i]-epsilon) ) + im*imag( f(x[i]+epsilon) - f(x[i]-epsilon) ) ) / (2 * epsilon)
252250
end
253251
else
254252
fdtype_error(Val{:Complex})

src/finitediff.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,17 @@ end
1414
eps_cbrt * max(one(T), abs(x))
1515
end
1616

17+
@inline function compute_epsilon(::Type{Val{:complex}}, x::T) where T<:Real
18+
eps(T)
19+
end
20+
1721
@inline function compute_epsilon_factor(fdtype::DataType, ::Type{T}) where T<:Number
1822
if fdtype==Val{:forward}
1923
return sqrt(eps(T))
2024
elseif fdtype==Val{:central}
2125
return cbrt(eps(T))
2226
else
23-
error("Unrecognized fdtype $fdtype: must be Val{:forward} or Val{:central}.")
27+
return one(T)
2428
end
2529
end
2630

src/gradients.jl

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -195,43 +195,25 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
195195
c2[i] += epsilon
196196
if typeof(fx) == Void
197197
df[i] = real(f(c2) - f(x)) / epsilon
198-
else
199-
df[i] = real(f(c2) - fx) / epsilon
200-
end
201-
c2[i] -= epsilon
202-
c2[i] += im*epsilon
203-
if typeof(fx) == Void
204198
df[i] += im*imag(f(c2) - f(x)) / epsilon
205199
else
200+
df[i] = real(f(c2) - fx) / epsilon
206201
df[i] += im*imag(f(c2) - fx) / epsilon
207202
end
208-
c2[i] -= im*epsilon
203+
c2[i] -= epsilon
209204
end
210205
elseif fdtype == Val{:central}
211206
@inbounds for i eachindex(x)
212207
epsilon = c1[i]
213208
c2[i] += epsilon
214209
x[i] -= epsilon
215210
df[i] = real(f(c2) - f(x)) / (2*epsilon)
216-
c2[i] -= c1[i]
217-
x[i] += c1[i]
218-
c2[i] += im*epsilon
219-
x[i] -= im*epsilon
220211
df[i] += im*imag(f(c2) - f(x)) / (2*epsilon)
221-
c2[i] -= im*epsilon
222-
x[i] += im*epsilon
223-
end
224-
elseif fdtype == Val{:complex}
225-
epsilon_elemtype = compute_epsilon_elemtype(nothing, x)
226-
epsilon_complex = eps(epsilon_elemtype)
227-
# we use c1 here to avoid typing issues with x
228-
@inbounds for i eachindex(x)
229-
c1[i] += im*epsilon_complex
230-
df[i] = imag(f(c1)) / epsilon_complex
231-
c1[i] -= im*epsilon_complex
212+
c2[i] -= c1[i]
213+
x[i] += c1[i]
232214
end
233215
else
234-
fdtype_error(Val{:Complex})
216+
fdtype_error(Val{:complex})
235217
end
236218
df
237219
end
@@ -241,6 +223,34 @@ end
241223
function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Number,
242224
cache::GradientCache{T1,T2,T3,fdtype,Val{:Complex}}) where {T1,T2,T3,fdtype}
243225

244-
error("Not implemented yet.")
226+
# NOTE: in this case epsilon is a scalar, we need two arrays for fx1 and fx2
227+
# c1 denotes fx1, c2 is fx2, sizes guaranteed by the cache constructor
228+
fx, c1, c2 = cache.fx, cache.c1, cache.c2
229+
230+
epsilon_elemtype = compute_epsilon_elemtype(nothing, x)
231+
if fdtype == Val{:forward}
232+
epsilon_factor = compute_epsilon_factor(fdtype, real(eltype(x)))
233+
epsilon = compute_epsilon(Val{:forward}, real(x), epsilon_factor)
234+
c1 .= f(x+epsilon)
235+
if typeof(fx) != Void
236+
@. df = real(c1 - fx) / epsilon
237+
@. df += im*imag(c1 - fx) / epsilon
238+
else
239+
c2 .= f(x)
240+
@. df = real(c1 - c2) / epsilon
241+
@. df += im*imag(c1 - c2) / epsilon
242+
end
243+
elseif fdtype == Val{:central}
244+
epsilon_factor = compute_epsilon_factor(fdtype, real(eltype(x)))
245+
epsilon = compute_epsilon(Val{:central}, real(x), epsilon_factor)
246+
c1 .= f(x+epsilon)
247+
c2 .= f(x-epsilon)
248+
@. df = real(c1 - c2) / (2*epsilon)
249+
#c1 .= im*imag(f(x+im*epsilon))
250+
#c2 .= im*imag(f(x-im*epsilon))
251+
@. df += im*imag(c1 - c2) / (2*epsilon)
252+
elseif fdtype == Val{:complex}
253+
fdtype_error(Val{:complex})
254+
end
245255
df
246256
end

test/finitedifftests.jl

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,19 @@ end
5252
end
5353

5454
x = x + im*x
55-
f(x) = cos(real(x)) + im*sin(imag(x))
55+
#f(x) = cos(real(x)) + im*sin(imag(x))
56+
f(x) = sin(x) + cos(x)
5657
y = f.(x)
5758
df = zeros(x)
5859
epsilon = zeros(length(x))
59-
df_ref = -sin.(real(x)) + im*cos.(imag(x))
60+
#df_ref = -sin.(real(x)) + im*cos.(imag(x))
61+
df_ref = cos.(x) - sin.(x)
6062
forward_cache = DiffEqDiffTools.DerivativeCache(x, y, epsilon, Val{:forward})
6163
central_cache = DiffEqDiffTools.DerivativeCache(x, nothing, epsilon, Val{:central})
6264

6365
@time @testset "Derivative single point complex-valued tests" begin
64-
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, π/4+im*π/4, Val{:forward}, Val{:Complex}), -√2/2 + im*2/2) < 1e-4
65-
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, π/4+im*π/4, Val{:central}, Val{:Complex}), -√2/2 + im*2/2) < 1e-8
66+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, π/4+im*π/4, Val{:forward}, Val{:Complex}), cos/4+im*π/4)-sin/4+im*π/4)) < 1e-4
67+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, π/4+im*π/4, Val{:central}, Val{:Complex}), cos/4+im*π/4)-sin/4+im*π/4)) < 1e-8
6668
end
6769

6870
@time @testset "Derivative StridedArray complex-valued tests" begin
@@ -157,8 +159,26 @@ complex_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:compl
157159
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, complex_cache), df_ref) < 1e-15
158160
end
159161

162+
f(x) = [sin(x), cos(x)]
163+
@show x = (2π * rand()) * (1 + im)
164+
@show fx = f(x)
165+
df = zeros(fx)
166+
@show df_ref = [cos(x), -sin(x)]
167+
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward},Val{:Real})
168+
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central},Val{:Real})
169+
complex_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:complex},Val{:Real})
170+
@show DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward})
171+
@show DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central})
172+
160173
@time @testset "Gradient of f:R->R^n complex-valued tests" begin
161-
# TODO
174+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}), df_ref) < 1e-4
175+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central}), df_ref) < 1e-7
176+
177+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:forward}), df_ref) < 1e-4
178+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:central}), df_ref) < 1e-7
179+
180+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, forward_cache), df_ref) < 1e-4
181+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, central_cache), df_ref) < 1e-7
162182
end
163183

164184
# Jacobian tests

0 commit comments

Comments
 (0)