Skip to content

Commit 114b946

Browse files
Merge pull request #17 from JuliaDiffEq/complexdiff
Support for derivatives and Jacobians of complex-valued callables
2 parents ff050f4 + 9a8afdc commit 114b946

File tree

6 files changed

+678
-262
lines changed

6 files changed

+678
-262
lines changed

src/derivatives.jl

+238
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
#=
2+
Compute the derivative df of a callable f on a collection of points x.
3+
Generic fallbacks for AbstractArrays that are not StridedArrays.
4+
=#
5+
function finite_difference(f, x::AbstractArray{<:Number},
6+
fdtype::DataType=Val{:central}, funtype::DataType=Val{:Real}, wrappertype::DataType=Val{:Default},
7+
fx::Union{Void,AbstractArray{<:Number}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
8+
9+
df = zeros(return_type, size(x))
10+
finite_difference!(df, f, x, fdtype, funtype, wrappertype, fx, epsilon, return_type)
11+
end
12+
13+
function finite_difference!(df::AbstractArray{<:Number}, f, x::AbstractArray{<:Number},
14+
fdtype::DataType=Val{:central}, funtype::DataType=Val{:Real}, wrappertype::DataType=Val{:Default},
15+
fx::Union{Void,AbstractArray{<:Number}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
16+
17+
finite_difference!(df, f, x, fdtype, funtype, wrappertype, fx, return_type)
18+
end
19+
20+
# Fallbacks for real-valued callables start here.
21+
function finite_difference!(df::AbstractArray{<:Real}, f, x::AbstractArray{<:Real},
22+
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:Default}},
23+
fx::Union{Void,AbstractArray{<:Real}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
24+
25+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
26+
if typeof(epsilon) == Void
27+
epsilon = zeros(epsilon_elemtype, size(x))
28+
end
29+
if fdtype == Val{:forward}
30+
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
31+
@. epsilon = compute_epsilon(Val{:forward}, x, epsilon_factor)
32+
if typeof(fx) == Void
33+
@. df = (f(x+epsilon) - f(x)) / epsilon
34+
else
35+
@. df = (f(x+epsilon) - fx) / epsilon
36+
end
37+
elseif fdtype == Val{:central}
38+
epsilon_factor = compute_epsilon_factor(Val{:central}, eltype(x))
39+
@. epsilon = compute_epsilon(Val{:central}, x, epsilon_factor)
40+
@. df = (f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
41+
elseif fdtype == Val{:complex}
42+
epsilon_complex = eps(epsilon_elemtype)
43+
@. df = imag(f(x+im*epsilon_complex)) / epsilon_complex
44+
else
45+
fdtype_error(Val{:Real})
46+
end
47+
df
48+
end
49+
# Fallbacks for real-valued callables end here.
50+
51+
# Fallbacks for complex-valued callables start here.
52+
function finite_difference!(df::AbstractArray{<:Number}, f, x::AbstractArray{<:Number},
53+
fdtype::DataType, ::Type{Val{:Complex}}, ::Type{Val{:Default}},
54+
fx::Union{Void,AbstractArray{<:Number}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
55+
56+
if (fdtype == Val{:forward} || fdtype == Val{:central}) && typeof(epsilon) == Void
57+
if eltype(x) <: Real
58+
epsilon = zeros(eltype(x), size(x))
59+
else
60+
epsilon = zeros(eltype(real(x)), size(x))
61+
end
62+
end
63+
if fdtype == Val{:forward}
64+
epsilon_factor = compute_epsilon_factor(Val{:forward}, eltype(epsilon))
65+
@. epsilon = compute_epsilon(Val{:forward}, real(x), epsilon_factor)
66+
if typeof(fx) == Void
67+
fx = f.(x)
68+
end
69+
@. df = real((f(x+epsilon) - fx)) / epsilon + im*imag((f(x+im*epsilon) - fx)) / epsilon
70+
elseif fdtype == Val{:central}
71+
epsilon_factor = compute_epsilon_factor(Val{:central}, eltype(epsilon))
72+
@. epsilon = compute_epsilon(Val{:central}, real(x), epsilon_factor)
73+
@. df = real(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon) + im*imag(f(x+im*epsilon) - f(x-epsilon)) / (2 * epsilon)
74+
else
75+
fdtype_error(Val{:Complex})
76+
end
77+
df
78+
end
79+
# Fallbacks for complex-valued callables end here.
80+
81+
#=
82+
Optimized implementations for StridedArrays.
83+
=#
84+
# for R -> R^n
85+
function finite_difference!(df::StridedArray{<:Real}, f, x::Real,
86+
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:Default}},
87+
fx::Union{Void,StridedArray{<:Real}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, return_type::DataType=eltype(x))
88+
89+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
90+
if fdtype == Val{:forward}
91+
epsilon = compute_epsilon(Val{:forward}, x)
92+
if typeof(fx) == Void
93+
df .= (f(x+epsilon) - f(x)) / epsilon
94+
else
95+
df .= (f(x+epsilon) - fx) / epsilon
96+
end
97+
elseif fdtype == Val{:central}
98+
epsilon = compute_epsilon(Val{:central}, x)
99+
df .= (f(x+epsilon) - f(x-epsilon)) / (2*epsilon)
100+
elseif fdtype == Val{:complex}
101+
epsilon = eps(eltype(x))
102+
df .= imag(f(x+im*epsilon)) / epsilon
103+
else
104+
fdtype_error(Val{:Real})
105+
end
106+
df
107+
end
108+
109+
# for R^n -> R^n
110+
function finite_difference!(df::StridedArray{<:Real}, f, x::StridedArray{<:Real},
111+
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:Default}},
112+
fx::Union{Void,StridedArray{<:Real}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, return_type::DataType=eltype(x))
113+
114+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
115+
if fdtype == Val{:forward}
116+
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
117+
@inbounds for i in 1 : length(x)
118+
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)
119+
x_plus = x[i] + epsilon
120+
if typeof(fx) == Void
121+
df[i] = (f(x_plus) - f(x[i])) / epsilon
122+
else
123+
df[i] = (f(x_plus) - fx[i]) / epsilon
124+
end
125+
end
126+
elseif fdtype == Val{:central}
127+
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
128+
@inbounds for i in 1 : length(x)
129+
epsilon = compute_epsilon(Val{:central}, x[i], epsilon_factor)
130+
epsilon_double_inv = one(typeof(epsilon)) / (2*epsilon)
131+
x_plus, x_minus = x[i]+epsilon, x[i]-epsilon
132+
df[i] = (f(x_plus) - f(x_minus)) * epsilon_double_inv
133+
end
134+
elseif fdtype == Val{:complex}
135+
epsilon_complex = eps(eltype(x))
136+
@inbounds for i in 1 : length(x)
137+
df[i] = imag(f(x[i]+im*epsilon_complex)) / epsilon_complex
138+
end
139+
else
140+
fdtype_error(Val{:Real})
141+
end
142+
df
143+
end
144+
145+
# C -> C^n
146+
function finite_difference!(df::StridedArray{<:Number}, f, x::Number,
147+
fdtype::DataType, ::Type{Val{:Complex}}, ::Type{Val{:Default}},
148+
fx::Union{Void,StridedArray{<:Number}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, return_type::DataType=eltype(x))
149+
150+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
151+
if fdtype == Val{:forward}
152+
epsilon = compute_epsilon(Val{:forward}, real(x[i]))
153+
if typeof(fx) == Void
154+
df .= ( real( f(x+epsilon) - f(x) ) + im*imag( f(x+im*epsilon) - f(x) ) ) / epsilon
155+
else
156+
df .= ( real( f(x+epsilon) - fx ) + im*imag( f(x+im*epsilon) - fx )) / epsilon
157+
end
158+
elseif fdtype == Val{:central}
159+
epsilon = compute_epsilon(Val{:central}, real(x[i]))
160+
df .= (real(f(x+epsilon) - f(x-epsilon)) + im*imag(f(x+im*epsilon) - f(x-im*epsilon))) / (2 * epsilon)
161+
else
162+
fdtype_error(Val{:Complex})
163+
end
164+
df
165+
end
166+
167+
# C^n -> C^n
168+
function finite_difference!(df::StridedArray{<:Number}, f, x::StridedArray{<:Number},
169+
fdtype::DataType, ::Type{Val{:Complex}}, ::Type{Val{:Default}},
170+
fx::Union{Void,StridedArray{<:Number}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, return_type::DataType=eltype(x))
171+
172+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
173+
if fdtype == Val{:forward}
174+
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
175+
@inbounds for i in 1 : length(x)
176+
epsilon = compute_epsilon(Val{:forward}, real(x[i]), epsilon_factor)
177+
if typeof(fx) == Void
178+
df[i] = ( real( f(x[i]+epsilon) - f(x[i]) ) + im*imag( f(x[i]+im*epsilon) - f(x[i]) ) ) / epsilon
179+
else
180+
df[i] = ( real( f(x[i]+epsilon) - fx[i] ) + im*imag( f(x[i]+im*epsilon) - fx[i] )) / epsilon
181+
end
182+
end
183+
elseif fdtype == Val{:central}
184+
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
185+
@inbounds for i in 1 : length(x)
186+
epsilon = compute_epsilon(Val{:central}, real(x[i]), epsilon_factor)
187+
df[i] = (real(f(x[i]+epsilon) - f(x[i]-epsilon)) + im*imag(f(x[i]+im*epsilon) - f(x[i]-im*epsilon))) / (2 * epsilon)
188+
end
189+
else
190+
fdtype_error(Val{:Complex})
191+
end
192+
df
193+
end
194+
195+
#=
196+
Compute the derivative df of a callable f on a collection of points x.
197+
Single point implementations.
198+
=#
199+
function finite_difference(f, x::T, fdtype::DataType, funtype::DataType=Val{:Real}, f_x::Union{Void,T}=nothing) where T<:Number
200+
if funtype == Val{:Real}
201+
if fdtype == Val{:complex}
202+
epsilon = eps(T)
203+
return imag(f(x+im*epsilon)) / epsilon
204+
else
205+
epsilon = compute_epsilon(fdtype, x)
206+
return finite_difference_kernel(f, x, fdtype, funtype, epsilon, f_x)
207+
end
208+
elseif funtype == Val{:Complex}
209+
epsilon = compute_epsilon(fdtype, real(x))
210+
return finite_difference_kernel(f, x, fdtype, funtype, epsilon, f_x)
211+
else
212+
fdtype_error(funtype)
213+
end
214+
end
215+
216+
@inline function finite_difference_kernel(f, x::T, ::Type{Val{:forward}}, ::Type{Val{:Real}}, epsilon::T, fx::Union{Void,T}=nothing) where T<:Real
217+
if typeof(fx) == Void
218+
return (f(x+epsilon) - f(x)) / epsilon
219+
else
220+
return (f(x+epsilon) - fx) / epsilon
221+
end
222+
end
223+
224+
@inline function finite_difference_kernel(f, x::T, ::Type{Val{:central}}, ::Type{Val{:Real}}, epsilon::T, ::Union{Void,T}=nothing) where T<:Real
225+
(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
226+
end
227+
228+
@inline function finite_difference_kernel(f, x::Number, ::Type{Val{:forward}}, ::Type{Val{:Complex}}, epsilon::Real, fx::Union{Void,<:Number}=nothing)
229+
if typeof(fx) == Void
230+
return real((f(x[i]+epsilon) - f(x[i]))) / epsilon + im*imag((f(x[i]+im*epsilon) - f(x[i]))) / epsilon
231+
else
232+
return real((f(x[i]+epsilon) - fx[i])) / epsilon + im*imag((f(x[i]+im*epsilon) - fx[i])) / epsilon
233+
end
234+
end
235+
236+
@inline function finite_difference_kernel(f, x::Number, ::Type{Val{:central}}, ::Type{Val{:Complex}}, epsilon::Real, fx::Union{Void,<:Number}=nothing)
237+
real(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon) + im*imag(f(x+im*epsilon) - f(x-im*epsilon)) / (2 * epsilon)
238+
end

src/diffeqwrappers.jl

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
function finite_difference!(df::AbstractArray{<:Number}, f, x::Union{Number,AbstractArray{<:Number}},
2+
fdtype::DataType, funtype::DataType, ::Type{Val{:DiffEqDerivativeWrapper}},
3+
fx::Union{Void,AbstractArray{<:Number}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
4+
5+
# TODO: optimized implementations for specific wrappers using the added DiffEq caching where appopriate
6+
7+
finite_difference!(df, f, x, fdtype, funtype, Val{:Default}, fx, epsilon, return_type)
8+
df
9+
end
10+
11+
function finite_difference_jacobian!(J::AbstractMatrix{<:Real}, f, x::AbstractArray{<:Real},
12+
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:JacobianWrapper}},
13+
fx::AbstractArray{<:Real}, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
14+
15+
m, n = size(J)
16+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
17+
x1, fx1 = f.x1, f.fx1
18+
copy!(x1, x)
19+
if fdtype == Val{:forward}
20+
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
21+
@inbounds for i 1:n
22+
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)
23+
x1[i] += epsilon
24+
f(fx1, x1)
25+
f(fx, x)
26+
@. J[:,i] = (fx1 - fx) / epsilon
27+
x1[i] -= epsilon
28+
end
29+
elseif fdtype == Val{:central}
30+
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
31+
@inbounds for i 1:n
32+
epsilon = compute_epsilon(Val{:central}, x[i], epsilon_factor)
33+
x1[i] += epsilon
34+
x[i] -= epsilon
35+
f(fx1, x1)
36+
f(fx, x)
37+
@. J[:,i] = (fx1 - fx) / (2*epsilon)
38+
x1[i] -= epsilon
39+
x[i] += epsilon
40+
end
41+
elseif fdtype == Val{:complex}
42+
x0 = Complex{eltype(x)}(x)
43+
epsilon = eps(eltype(x))
44+
@inbounds for i 1:n
45+
x0[i] += im * epsilon
46+
@. J[:,i] = imag(f(x0)) / epsilon
47+
x0[i] -= im * epsilon
48+
end
49+
else
50+
fdtype_error(Val{:Real})
51+
end
52+
J
53+
end
54+
55+
function finite_difference_jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
56+
fdtype::DataType, ::Type{Val{:Complex}}, ::Type{Val{:JacobianWrapper}},
57+
fx::AbstractArray{<:Number}, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
58+
59+
# TODO: test this
60+
m, n = size(J)
61+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
62+
x1, fx1 = f.x1, f.fx1
63+
copy!(x1, x)
64+
if fdtype == Val{:forward}
65+
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
66+
@inbounds for i 1:n
67+
epsilon = compute_epsilon(Val{:forward}, real(x[i]), epsilon_factor)
68+
x1[i] += epsilon
69+
f(fx1, x1)
70+
f(fx, x)
71+
@. J[:,i] = ( real( (fx1 - fx) ) + im*imag( (fx1 - fx) ) ) / epsilon
72+
x1[i] -= epsilon
73+
end
74+
elseif fdtype == Val{:central}
75+
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
76+
@inbounds for i 1:n
77+
epsilon = compute_epsilon(Val{:central}, real(x[i]), epsilon_factor)
78+
x1[i] += epsilon
79+
x[i] -= epsilon
80+
f(fx1, x1)
81+
f(fx, x)
82+
@. J[:,i] = ( real( (fx1 - fx) ) + im*imag( fx1 - fx ) ) / (2*epsilon)
83+
x1[i] -= epsilon
84+
x[i] += epsilon
85+
end
86+
else
87+
fdtype_error(Val{:Complex})
88+
end
89+
J
90+
end

0 commit comments

Comments
 (0)