Skip to content

Commit 73d249a

Browse files
committed
More tests and some refactoring.
1 parent 0e9ad46 commit 73d249a

File tree

6 files changed

+497
-470
lines changed

6 files changed

+497
-470
lines changed

src/derivatives.jl

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
#=
2+
Compute the derivative df of a callable f on a collection of points x.
3+
Generic fallbacks for AbstractArrays that are not StridedArrays.
4+
=#
5+
function finite_difference(f, x::AbstractArray{<:Number},
6+
fdtype::DataType=Val{:central}, funtype::DataType=Val{:Real}, wrappertype::DataType=Val{:Default},
7+
fx::Union{Void,AbstractArray{<:Number}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
8+
9+
df = zeros(return_type, size(x))
10+
finite_difference!(df, f, x, fdtype, funtype, wrappertype, fx, epsilon, return_type)
11+
end
12+
13+
function finite_difference!(df::AbstractArray{<:Number}, f, x::AbstractArray{<:Number},
14+
fdtype::DataType=Val{:central}, funtype::DataType=Val{:Real}, wrappertype::DataType=Val{:Default},
15+
fx::Union{Void,AbstractArray{<:Number}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
16+
17+
finite_difference!(df, f, x, fdtype, funtype, wrappertype, fx, return_type)
18+
end
19+
20+
# Fallbacks for real-valued callables start here.
21+
function finite_difference!(df::AbstractArray{<:Real}, f, x::AbstractArray{<:Real},
22+
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:Default}},
23+
fx::Union{Void,AbstractArray{<:Real}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
24+
25+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
26+
if typeof(epsilon) == Void
27+
epsilon = zeros(epsilon_elemtype, size(x))
28+
end
29+
if fdtype == Val{:forward}
30+
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
31+
@. epsilon = compute_epsilon(Val{:forward}, x, epsilon_factor)
32+
if typeof(fx) == Void
33+
@. df = (f(x+epsilon) - f(x)) / epsilon
34+
else
35+
@. df = (f(x+epsilon) - fx) / epsilon
36+
end
37+
elseif fdtype == Val{:central}
38+
epsilon_factor = compute_epsilon_factor(Val{:central}, eltype(x))
39+
@. epsilon = compute_epsilon(Val{:central}, x, epsilon_factor)
40+
@. df = (f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
41+
elseif fdtype == Val{:complex}
42+
epsilon_complex = eps(epsilon_elemtype)
43+
@. df = imag(f(x+im*epsilon_complex)) / epsilon_complex
44+
else
45+
error("Unrecognized fdtype: valid values are Val{:forward}, Val{:central} and Val{:complex}.")
46+
end
47+
df
48+
end
49+
# Fallbacks for real-valued callables end here.
50+
51+
# Fallbacks for complex-valued callables start here.
52+
function finite_difference!(df::AbstractArray{<:Number}, f, x::AbstractArray{<:Number},
53+
fdtype::DataType, ::Type{Val{:Complex}}, ::Type{Val{:Default}},
54+
fx::Union{Void,AbstractArray{<:Number}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
55+
56+
if (fdtype == Val{:forward} || fdtype == Val{:central}) && typeof(epsilon) == Void
57+
if eltype(x) <: Real
58+
epsilon = zeros(eltype(x), size(x))
59+
else
60+
epsilon = zeros(eltype(real(x)), size(x))
61+
end
62+
end
63+
if fdtype == Val{:forward}
64+
@show typeof(x)
65+
epsilon_factor = compute_epsilon_factor(Val{:forward}, eltype(epsilon))
66+
@. epsilon = compute_epsilon(Val{:forward}, real(x), epsilon_factor)
67+
if typeof(fx) == Void
68+
fx = f.(x)
69+
end
70+
@. df = real((f(x+epsilon) - fx)) / epsilon + im*imag((f(x+im*epsilon) - fx)) / epsilon
71+
elseif fdtype == Val{:central}
72+
epsilon_factor = compute_epsilon_factor(Val{:central}, eltype(epsilon))
73+
@. epsilon = compute_epsilon(Val{:central}, real(x), epsilon_factor)
74+
@. df = real(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon) + im*imag(f(x+im*epsilon) - f(x-epsilon)) / (2 * epsilon)
75+
elseif fdtype == Val{:complex}
76+
error("Invalid fdtype value, Val{:complex} not implemented for complex-valued functions.")
77+
end
78+
df
79+
end
80+
# Fallbacks for complex-valued callables end here.
81+
82+
#=
83+
Optimized implementations for StridedArrays.
84+
=#
85+
function finite_difference!(df::StridedArray{<:Real}, f, x::StridedArray{<:Real},
86+
::Type{Val{:central}}, ::Type{Val{:Real}}, ::Type{Val{:Default}},
87+
fx::Union{Void,StridedArray{<:Real}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, return_type::DataType=eltype(x))
88+
89+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
90+
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
91+
@inbounds for i in 1 : length(x)
92+
epsilon = compute_epsilon(Val{:central}, x[i], epsilon_factor)
93+
epsilon_double_inv = one(typeof(epsilon)) / (2*epsilon)
94+
x_plus, x_minus = x[i]+epsilon, x[i]-epsilon
95+
df[i] = (f(x_plus) - f(x_minus)) * epsilon_double_inv
96+
end
97+
df
98+
end
99+
100+
function finite_difference!(df::StridedArray{<:Real}, f, x::StridedArray{<:Real},
101+
::Type{Val{:forward}}, ::Type{Val{:Real}}, ::Type{Val{:Default}},
102+
fx::Union{Void,StridedArray{<:Real}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, return_type::DataType=eltype(x))
103+
104+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
105+
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
106+
@inbounds for i in 1 : length(x)
107+
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)
108+
x_plus = x[i] + epsilon
109+
if typeof(fx) == Void
110+
df[i] = (f(x_plus) - f(x[i])) / epsilon
111+
else
112+
df[i] = (f(x_plus) - fx[i]) / epsilon
113+
end
114+
end
115+
df
116+
end
117+
118+
function finite_difference!(df::StridedArray{<:Real}, f, x::StridedArray{<:Real},
119+
::Type{Val{:complex}}, ::Type{Val{:Real}}, ::Type{Val{:Default}},
120+
fx::Union{Void,StridedArray{<:Real}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, return_type::DataType=eltype(x))
121+
122+
epsilon_complex = eps(eltype(x))
123+
@inbounds for i in 1 : length(x)
124+
df[i] = imag(f(x[i]+im*epsilon_complex)) / epsilon_complex
125+
end
126+
df
127+
end
128+
129+
function finite_difference!(df::StridedArray{<:Number}, f, x::StridedArray{<:Number},
130+
fdtype::DataType, ::Type{Val{:Complex}}, ::Type{Val{:Default}},
131+
fx::Union{Void,StridedArray{<:Number}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, return_type::DataType=eltype(x))
132+
133+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
134+
if fdtype == Val{:forward}
135+
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
136+
@inbounds for i in 1 : length(x)
137+
epsilon = compute_epsilon(Val{:forward}, real(x[i]), epsilon_factor)
138+
if typeof(fx) == Void
139+
df[i] = ( real( f(x[i]+epsilon) - f(x[i]) ) + im*imag( f(x[i]+im*epsilon) - f(x[i]) ) ) / epsilon
140+
else
141+
df[i] = ( real( f(x[i]+epsilon) - fx[i] ) + im*imag( f(x[i]+im*epsilon) - fx[i] )) / epsilon
142+
end
143+
end
144+
elseif fdtype == Val{:central}
145+
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
146+
@inbounds for i in 1 : length(x)
147+
epsilon = compute_epsilon(Val{:central}, real(x[i]), epsilon_factor)
148+
df[i] = (real(f(x[i]+epsilon) - f(x[i]-epsilon)) + im*imag(f(x[i]+im*epsilon) - f(x[i]-im*epsilon))) / (2 * epsilon)
149+
end
150+
elseif fdtype == Val{:complex}
151+
error("Invalid fdtype value, Val{:complex} not implemented for complex-valued functions.")
152+
end
153+
df
154+
end
155+
156+
#=
157+
Compute the derivative df of a callable f on a collection of points x.
158+
Single point implementations.
159+
=#
160+
function finite_difference(f, x::T, fdtype::DataType, funtype::DataType=Val{:Real}, f_x::Union{Void,T}=nothing) where T<:Number
161+
if funtype == Val{:Real}
162+
if fdtype == Val{:complex}
163+
epsilon = eps(T)
164+
return imag(f(x+im*epsilon)) / epsilon
165+
else
166+
epsilon = compute_epsilon(fdtype, x)
167+
return finite_difference_kernel(f, x, fdtype, funtype, epsilon, f_x)
168+
end
169+
elseif funtype == Val{:Complex}
170+
epsilon = compute_epsilon(fdtype, real(x))
171+
return finite_difference_kernel(f, x, fdtype, funtype, epsilon, f_x)
172+
end
173+
end
174+
175+
@inline function finite_difference_kernel(f, x::T, ::Type{Val{:forward}}, ::Type{Val{:Real}}, epsilon::T, fx::Union{Void,T}=nothing) where T<:Real
176+
if typeof(fx) == Void
177+
return (f(x+epsilon) - f(x)) / epsilon
178+
else
179+
return (f(x+epsilon) - fx) / epsilon
180+
end
181+
end
182+
183+
@inline function finite_difference_kernel(f, x::T, ::Type{Val{:central}}, ::Type{Val{:Real}}, epsilon::T, ::Union{Void,T}=nothing) where T<:Real
184+
(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
185+
end
186+
187+
@inline function finite_difference_kernel(f, x::Number, ::Type{Val{:forward}}, ::Type{Val{:Complex}}, epsilon::Real, fx::Union{Void,<:Number}=nothing)
188+
if typeof(fx) == Void
189+
return real((f(x[i]+epsilon) - f(x[i]))) / epsilon + im*imag((f(x[i]+im*epsilon) - fx[i])) / epsilon
190+
else
191+
return real((f(x[i]+epsilon) - fx[i])) / epsilon + im*imag((f(x[i]+im*epsilon) - fx[i])) / epsilon
192+
end
193+
end
194+
195+
@inline function finite_difference_kernel(f, x::Number, ::Type{Val{:central}}, ::Type{Val{:Complex}}, epsilon::Real, fx::Union{Void,<:Number}=nothing)
196+
real(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon) + im*imag(f(x+im*epsilon) - f(x-im*epsilon)) / (2 * epsilon)
197+
end

src/diffeqwrappers.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
function finite_difference!(df::AbstractArray{<:Real}, f, x::AbstractArray{<:Real},
2+
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:DiffEqDerivativeWrapper}},
3+
fx::Union{Void,AbstractArray{<:Real}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
4+
5+
# TODO: test this one, and figure out what happens with epsilon
6+
fx1 = f.fx1
7+
if fdtype == Val{:forward}
8+
epsilon = compute_epsilon(Val{:forward}, x)
9+
f(fx, x)
10+
f(fx1, x+epsilon)
11+
@. df = (fx1 - fx) / epsilon
12+
elseif fdtype == Val{:central}
13+
epsilon = compute_epsilon(Val{:central}, x)
14+
f(fx, x-epsilon)
15+
f(fx1, x+epsilon)
16+
@. df = (fx1 - fx) / (2 * epsilon)
17+
elseif fdtype == Val{:complex}
18+
epsilon = eps(eltype(x))
19+
f(fx, f(x+im*epsilon))
20+
@. df = imag(fx) / epsilon
21+
end
22+
df
23+
end
24+
25+
# AbstractArray{T} should be OK if JacobianWrapper is provided
26+
function finite_difference_jacobian!(J::AbstractArray{T}, f, x::StridedArray{T}, ::Type{Val{:forward}}, fx::StridedArray{T}, ::Type{Val{:JacobianWrapper}}) where T<:Real
27+
m, n = size(J)
28+
epsilon_factor = compute_epsilon_factor(Val{:forward}, T)
29+
x1, fx1 = f.x1, f.fx1
30+
copy!(x1, x)
31+
copy!(fx1, fx)
32+
@inbounds for i in 1:n
33+
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)
34+
epsilon_inv = one(T) / epsilon
35+
x1[i] += epsilon
36+
f(fx, x)
37+
f(fx1, x1)
38+
@. J[:,i] = (fx-fx1) * epsilon_inv
39+
x1[i] -= epsilon
40+
end
41+
J
42+
end
43+
44+
function finite_difference_jacobian!(J::AbstractArray{T}, f, x::StridedArray{T}, ::Type{Val{:central}}, fx::StridedArray{T}, ::Type{Val{:JacobianWrapper}}) where T<:Real
45+
m, n = size(J)
46+
epsilon_factor = compute_epsilon_factor(Val{:central}, T)
47+
x1, fx1 = f.x1, f.fx1
48+
copy!(x1, x)
49+
copy!(fx1, fx)
50+
@inbounds for i in 1:n
51+
epsilon = compute_epsilon(Val{:central}, x[i], epsilon_factor)
52+
epsilon_double_inv = one(T) / (2 * epsilon)
53+
x[i] += epsilon
54+
x1[i] -= epsilon
55+
f(fx, x)
56+
f(fx1, x1)
57+
@. J[:,i] = (fx-fx1) * epsilon_double_inv
58+
x[i] -= epsilon
59+
x1[i] += epsilon
60+
end
61+
J
62+
end

0 commit comments

Comments
 (0)