Skip to content

Commit a75114c

Browse files
Merge pull request #24 from JuliaDiffEq/diffeqwrappers
fix diffeqwrapper derivatives
2 parents 5b65ff8 + 99f324c commit a75114c

File tree

2 files changed

+22
-45
lines changed

2 files changed

+22
-45
lines changed

src/derivatives.jl

+7-36
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,7 @@ function _finite_difference!(df::AbstractArray{<:Real}, f, x::AbstractArray{<:Re
2929
if fdtype == Val{:forward}
3030
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
3131
@. epsilon = compute_epsilon(Val{:forward}, x, epsilon_factor)
32-
if typeof(fx) == Void
33-
@. df = (f(x+epsilon) - f(x)) / epsilon
34-
else
35-
@. df = (f(x+epsilon) - fx) / epsilon
36-
end
32+
@. df = (f(x+epsilon) - f(x)) / epsilon
3733
elseif fdtype == Val{:central}
3834
epsilon_factor = compute_epsilon_factor(Val{:central}, eltype(x))
3935
@. epsilon = compute_epsilon(Val{:central}, x, epsilon_factor)
@@ -85,15 +81,10 @@ Optimized implementations for StridedArrays.
8581
function _finite_difference!(df::StridedArray{<:Real}, f, x::Real,
8682
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:Default}},
8783
fx, epsilon, return_type)
88-
8984
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
9085
if fdtype == Val{:forward}
9186
epsilon = compute_epsilon(Val{:forward}, x)
92-
if typeof(fx) == Void
93-
df .= (f(x+epsilon) - f(x)) / epsilon
94-
else
95-
df .= (f(x+epsilon) - fx) / epsilon
96-
end
87+
df .= (f(x+epsilon) - f(x)) / epsilon
9788
elseif fdtype == Val{:central}
9889
epsilon = compute_epsilon(Val{:central}, x)
9990
df .= (f(x+epsilon) - f(x-epsilon)) / (2*epsilon)
@@ -117,11 +108,7 @@ function _finite_difference!(df::StridedArray{<:Real}, f, x::StridedArray{<:Real
117108
@inbounds for i in 1 : length(x)
118109
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)
119110
x_plus = x[i] + epsilon
120-
if typeof(fx) == Void
121-
df[i] = (f(x_plus) - f(x[i])) / epsilon
122-
else
123-
df[i] = (f(x_plus) - fx[i]) / epsilon
124-
end
111+
df[i] = (f(x_plus) - f(x[i])) / epsilon
125112
end
126113
elseif fdtype == Val{:central}
127114
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
@@ -150,11 +137,7 @@ function _finite_difference!(df::StridedArray{<:Number}, f, x::Number,
150137
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
151138
if fdtype == Val{:forward}
152139
epsilon = compute_epsilon(Val{:forward}, real(x[i]))
153-
if typeof(fx) == Void
154-
df .= ( real( f(x+epsilon) - f(x) ) + im*imag( f(x+im*epsilon) - f(x) ) ) / epsilon
155-
else
156-
df .= ( real( f(x+epsilon) - fx ) + im*imag( f(x+im*epsilon) - fx )) / epsilon
157-
end
140+
df .= ( real( f(x+epsilon) - f(x) ) + im*imag( f(x+im*epsilon) - f(x) ) ) / epsilon
158141
elseif fdtype == Val{:central}
159142
epsilon = compute_epsilon(Val{:central}, real(x[i]))
160143
df .= (real(f(x+epsilon) - f(x-epsilon)) + im*imag(f(x+im*epsilon) - f(x-im*epsilon))) / (2 * epsilon)
@@ -174,11 +157,7 @@ function _finite_difference!(df::StridedArray{<:Number}, f, x::StridedArray{<:Nu
174157
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
175158
@inbounds for i in 1 : length(x)
176159
epsilon = compute_epsilon(Val{:forward}, real(x[i]), epsilon_factor)
177-
if typeof(fx) == Void
178-
df[i] = ( real( f(x[i]+epsilon) - f(x[i]) ) + im*imag( f(x[i]+im*epsilon) - f(x[i]) ) ) / epsilon
179-
else
180-
df[i] = ( real( f(x[i]+epsilon) - fx[i] ) + im*imag( f(x[i]+im*epsilon) - fx[i] )) / epsilon
181-
end
160+
df[i] = ( real( f(x[i]+epsilon) - f(x[i]) ) + im*imag( f(x[i]+im*epsilon) - f(x[i]) ) ) / epsilon
182161
end
183162
elseif fdtype == Val{:central}
184163
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
@@ -214,23 +193,15 @@ function finite_difference(f, x::T, fdtype::DataType, funtype::DataType=Val{:Rea
214193
end
215194

216195
@inline function finite_difference_kernel(f, x::T, ::Type{Val{:forward}}, ::Type{Val{:Real}}, epsilon::T, fx::Union{Void,T}=nothing) where T<:Real
217-
if typeof(fx) == Void
218-
return (f(x+epsilon) - f(x)) / epsilon
219-
else
220-
return (f(x+epsilon) - fx) / epsilon
221-
end
196+
return (f(x+epsilon) - f(x)) / epsilon
222197
end
223198

224199
@inline function finite_difference_kernel(f, x::T, ::Type{Val{:central}}, ::Type{Val{:Real}}, epsilon::T, ::Union{Void,T}=nothing) where T<:Real
225200
(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
226201
end
227202

228203
@inline function finite_difference_kernel(f, x::Number, ::Type{Val{:forward}}, ::Type{Val{:Complex}}, epsilon::Real, fx::Union{Void,<:Number}=nothing)
229-
if typeof(fx) == Void
230-
return real((f(x[i]+epsilon) - f(x[i]))) / epsilon + im*imag((f(x[i]+im*epsilon) - f(x[i]))) / epsilon
231-
else
232-
return real((f(x[i]+epsilon) - fx[i])) / epsilon + im*imag((f(x[i]+im*epsilon) - fx[i])) / epsilon
233-
end
204+
return real((f(x[i]+epsilon) - f(x[i]))) / epsilon + im*imag((f(x[i]+im*epsilon) - f(x[i]))) / epsilon
234205
end
235206

236207
@inline function finite_difference_kernel(f, x::Number, ::Type{Val{:central}}, ::Type{Val{:Complex}}, epsilon::Real, fx::Union{Void,<:Number}=nothing)

src/diffeqwrappers.jl

+15-9
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ end
1111
function finite_difference_jacobian!(J::AbstractMatrix{<:Real}, f, x::AbstractArray{<:Real},
1212
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:JacobianWrapper}},
1313
fx::AbstractArray{<:Real}, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
14-
1514
m, n = size(J)
1615
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
1716
x1, fx1 = f.x1, f.fx1
@@ -21,31 +20,38 @@ function finite_difference_jacobian!(J::AbstractMatrix{<:Real}, f, x::AbstractAr
2120
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
2221
@inbounds for i 1:n
2322
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)
23+
x1_save = x1[i]
2424
x1[i] += epsilon
2525
f(fx1, x1)
2626
f(fx, x)
27-
@. J[:,i] = (vfx1 - vfx) / epsilon
28-
x1[i] -= epsilon
27+
@. J[:,i] = (vfx - vfx1) / epsilon
28+
x1[i] = x1_save
2929
end
3030
elseif fdtype == Val{:central}
3131
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
3232
@inbounds for i 1:n
3333
epsilon = compute_epsilon(Val{:central}, x[i], epsilon_factor)
34+
x1_save = x1[i]
35+
x_save = x[i]
3436
x1[i] += epsilon
3537
x[i] -= epsilon
3638
f(fx1, x1)
3739
f(fx, x)
38-
@. J[:,i] = (vfx1 - vfx) / (2*epsilon)
39-
x1[i] -= epsilon
40-
x[i] += epsilon
40+
@. J[:,i] = (vfx - vfx1) / (2*epsilon)
41+
x1[i] = x1_save
42+
x[i] = x_save
4143
end
4244
elseif fdtype == Val{:complex}
43-
x0 = Complex{eltype(x)}(x)
45+
x0 = Complex{eltype(x)}.(x)
46+
cfx1 = Complex{eltype(x)}.(fx1)
47+
vcfx1 = vec(cfx1)
4448
epsilon = eps(eltype(x))
4549
@inbounds for i 1:n
50+
x0_save = x0[i]
4651
x0[i] += im * epsilon
47-
@. J[:,i] = imag(f(x0)) / epsilon
48-
x0[i] -= im * epsilon
52+
f(cfx1,x0)
53+
@. J[:,i] = imag(vcfx1) / epsilon # Fix allocation
54+
x0[i] = x0_save
4955
end
5056
else
5157
fdtype_error(Val{:Real})

0 commit comments

Comments
 (0)