Skip to content

Commit b76ba94

Browse files
committed
Add complex step differentiation
1 parent e16f147 commit b76ba94

File tree

2 files changed

+51
-36
lines changed

2 files changed

+51
-36
lines changed

src/finitediff.jl

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@ end
1414
eps_cbrt * max(one(T), abs(x))
1515
end
1616

17-
@inline function compute_epsilon{T<:Real}(::Type{Val{:complex}}, x::T, ::Union{Void,T}=nothing)
18-
eps(x)
19-
end
20-
2117
@inline function compute_epsilon_factor{T<:Real}(fdtype::DataType, ::Type{T})
2218
if fdtype==Val{:forward}
2319
return sqrt(eps(T))
@@ -40,48 +36,41 @@ function finite_difference{T<:Real}(f, x::AbstractArray{T}, fdtype::DataType, fx
4036
end
4137

4238
function finite_difference!{T<:Real}(df::AbstractArray{T}, f, x::AbstractArray{T}, fdtype::DataType, fx::Union{Void,AbstractArray{T}}, ::Type{Val{:Default}})
43-
epsilon_factor = compute_epsilon_factor(fdtype, T)
44-
@. epsilon = compute_epsilon(fdtype, x, epsilon_factor)
4539
if fdtype == Val{:forward}
40+
epsilon_factor = compute_epsilon_factor(fdtype, T)
41+
@. epsilon = compute_epsilon(fdtype, x, epsilon_factor)
4642
if typeof(fx) == Void
4743
@. df = (f(x+epsilon) - f(x)) / epsilon
4844
else
4945
@. df = (f(x+epsilon) - fx) / epsilon
5046
end
5147
elseif fdtype == Val{:central}
48+
epsilon_factor = compute_epsilon_factor(fdtype, T)
49+
@. epsilon = compute_epsilon(fdtype, x, epsilon_factor)
5250
@. df = (f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
53-
end
54-
df
55-
end
56-
57-
function finite_difference!{T<:Real}(df::AbstractArray{T}, f, x::AbstractArray{T}, fdtype::DataType, fx::Union{Void,AbstractArray{T}}, ::Type{Val{:DiffEqDerivativeWrapper}})
58-
epsilon_factor = compute_epsilon_factor(fdtype, T)
59-
@. epsilon = compute_epsilon(fdtype, x, epsilon_factor)
60-
error("Not implemented yet.")
61-
62-
if fdtype == Val{:forward}
63-
if typeof(fx) == Void
64-
65-
else
66-
67-
end
68-
elseif fdtype == Val{:central}
69-
51+
elseif fdtype == Val{:complex}
52+
epsilon = eps(T)
53+
@. df = imag(f(x+im*epsilon)) / epsilon
7054
end
7155
df
7256
end
7357

7458
function finite_difference!{T<:Real}(df::AbstractArray{T}, f, x::T, fdtype::DataType, fx::AbstractArray{T}, ::Type{Val{:DiffEqDerivativeWrapper}})
75-
epsilon = compute_epsilon(fdtype, x)
7659
fx1 = f.fx1
7760
if fdtype == Val{:forward}
61+
epsilon = compute_epsilon(fdtype, x)
7862
f(fx, x)
7963
f(fx1, x+epsilon)
8064
@. df = (fx1 - fx) / epsilon
8165
elseif fdtype == Val{:central}
66+
epsilon = compute_epsilon(fdtype, x)
8267
f(fx, x-epsilon)
8368
f(fx1, x+epsilon)
8469
@. df = (fx1 - fx) / (2 * epsilon)
70+
elseif fdtype == Val{:complex}
71+
epsilon = eps(T)
72+
f(fx, f(x+im*epsilon))
73+
@. df = imag(fx) / epsilon
8574
end
8675
df
8776
end
@@ -128,8 +117,13 @@ Compute the derivative df of a real-valued callable f on a collection of points
128117
Single point implementations.
129118
=#
130119
function finite_difference{T<:Real}(f, x::T, fdtype::DataType, f_x::Union{Void,T}=nothing)
131-
epsilon = compute_epsilon(fdtype, x)
132-
finite_difference_kernel(f, x, fdtype, epsilon, f_x)
120+
if fdtype == Val{:complex}
121+
epsilon = eps(T)
122+
return imag(f(x+im*epsilon)) / epsilon
123+
else
124+
epsilon = compute_epsilon(fdtype, x)
125+
return finite_difference_kernel(f, x, fdtype, epsilon, f_x)
126+
end
133127
end
134128

135129
@inline function finite_difference_kernel{T<:Real}(f, x::T, ::Type{Val{:forward}}, epsilon::T, fx::Union{Void,T})
@@ -154,7 +148,7 @@ function finite_difference_jacobian{T<:Real}(f, x::AbstractArray{T}, fdtype::Dat
154148
if funtype==Val{:Default}
155149
fx = f.(x)
156150
elseif funtype==Val{:DiffEqJacobianWrapper}
157-
f(fx, x)
151+
fx = f(x)
158152
else
159153
error("Unrecognized funtype: must be Val{:Default} or Val{:DiffEqJacobianWrapper}.")
160154
end
@@ -225,6 +219,22 @@ function finite_difference_jacobian!{T<:Real}(J::StridedArray{T}, f, x::StridedA
225219
J
226220
end
227221

222+
function finite_difference_jacobian!{T<:Real}(J::StridedArray{T}, f, x::StridedArray{T}, ::Type{Val{:complex}}, fx::StridedArray{T}, ::Type{Val{:Default}})
223+
m, n = size(J)
224+
epsilon = eps(T)
225+
epsilon_inv = one(T) / epsilon
226+
@inbounds for i in 1:n
227+
for j in 1:m
228+
if i==j
229+
J[j,i] = imag(f(x[j]+im*epsilon)) * epsilon_inv
230+
else
231+
J[j,i] = zero(T)
232+
end
233+
end
234+
end
235+
J
236+
end
237+
228238
# efficient implementations for OrdinaryDiffEq Jacobian wrappers, assuming the system function supplies StridedArrays
229239
function finite_difference_jacobian!{T<:Real}(J::StridedArray{T}, f, x::StridedArray{T}, ::Type{Val{:forward}}, fx::StridedArray{T}, ::Type{Val{:JacobianWrapper}})
230240
m, n = size(J)

test/finitedifftests.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,23 @@ x = collect(linspace(-2π, 2π, 100))
22
y = sin.(x)
33
df = zeros(100)
44
df_ref = cos.(x)
5+
J_ref = diagm(cos.(x))
6+
7+
err_func(a,b) = maximum(abs.(a-b))
58

69
# TODO: add tests for non-StridedArrays and with more complicated functions
710

811
# derivative tests
9-
@test maximum(abs.(DiffEqDiffTools.finite_difference(sin, x, Val{:forward}) - df_ref)) < 1e-4
10-
@test maximum(abs.(DiffEqDiffTools.finite_difference(sin, x, Val{:forward}, y) - df_ref)) < 1e-4
11-
@test maximum(abs.(DiffEqDiffTools.finite_difference(sin, x, Val{:central}) - df_ref)) < 1e-8
12-
@test maximum(abs.(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:forward}, nothing, Val{:Default}) - df_ref)) < 1e-4
13-
@test maximum(abs.(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:forward}, y, Val{:Default}) - df_ref)) < 1e-4
14-
@test maximum(abs.(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:central}, nothing, Val{:Default}) - df_ref)) < 1e-8
12+
@test err_func(DiffEqDiffTools.finite_difference(sin, x, Val{:forward}), df_ref) < 1e-4
13+
@test err_func(DiffEqDiffTools.finite_difference(sin, x, Val{:forward}, y), df_ref) < 1e-4
14+
@test err_func(DiffEqDiffTools.finite_difference(sin, x, Val{:central}), df_ref) < 1e-8
15+
@test err_func(DiffEqDiffTools.finite_difference(sin, x, Val{:complex}), df_ref) < 1e-15
16+
@test err_func(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:forward}, nothing, Val{:Default}), df_ref) < 1e-4
17+
@test err_func(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:forward}, y, Val{:Default}), df_ref) < 1e-4
18+
@test err_func(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:central}, nothing, Val{:Default}), df_ref) < 1e-8
19+
@test err_func(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:complex}, nothing, Val{:Default}), df_ref) < 1e-15
1520

1621
# Jacobian tests
17-
using Calculus
18-
@test DiffEqDiffTools.finite_difference_jacobian(sin, x, Val{:forward}) Calculus.finite_difference_jacobian(sin, x, :forward)
19-
@test DiffEqDiffTools.finite_difference_jacobian(sin, x, Val{:central}) Calculus.finite_difference_jacobian(sin, x, :central)
22+
@test err_func(DiffEqDiffTools.finite_difference_jacobian(sin, x, Val{:forward}), J_ref) < 1e-4
23+
@test err_func(DiffEqDiffTools.finite_difference_jacobian(sin, x, Val{:central}), J_ref) < 1e-8
24+
@test err_func(DiffEqDiffTools.finite_difference_jacobian(sin, x, Val{:complex}), J_ref) < 1e-15

0 commit comments

Comments
 (0)