Skip to content

Commit 4803b91

Browse files
committed
Preliminary work on complex derivatives complete, everything should at least work.
1 parent 73d249a commit 4803b91

File tree

4 files changed

+159
-83
lines changed

4 files changed

+159
-83
lines changed

src/derivatives.jl

Lines changed: 71 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function finite_difference!(df::AbstractArray{<:Real}, f, x::AbstractArray{<:Rea
4242
epsilon_complex = eps(epsilon_elemtype)
4343
@. df = imag(f(x+im*epsilon_complex)) / epsilon_complex
4444
else
45-
error("Unrecognized fdtype: valid values are Val{:forward}, Val{:central} and Val{:complex}.")
45+
fdtype_error(Val{:Real})
4646
end
4747
df
4848
end
@@ -61,7 +61,6 @@ function finite_difference!(df::AbstractArray{<:Number}, f, x::AbstractArray{<:N
6161
end
6262
end
6363
if fdtype == Val{:forward}
64-
@show typeof(x)
6564
epsilon_factor = compute_epsilon_factor(Val{:forward}, eltype(epsilon))
6665
@. epsilon = compute_epsilon(Val{:forward}, real(x), epsilon_factor)
6766
if typeof(fx) == Void
@@ -72,8 +71,8 @@ function finite_difference!(df::AbstractArray{<:Number}, f, x::AbstractArray{<:N
7271
epsilon_factor = compute_epsilon_factor(Val{:central}, eltype(epsilon))
7372
@. epsilon = compute_epsilon(Val{:central}, real(x), epsilon_factor)
7473
@. df = real(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon) + im*imag(f(x+im*epsilon) - f(x-epsilon)) / (2 * epsilon)
75-
elseif fdtype == Val{:complex}
76-
error("Invalid fdtype value, Val{:complex} not implemented for complex-valued functions.")
74+
else
75+
fdtype_error(Val{:Complex})
7776
end
7877
df
7978
end
@@ -82,50 +81,90 @@ end
8281
#=
8382
Optimized implementations for StridedArrays.
8483
=#
85-
function finite_difference!(df::StridedArray{<:Real}, f, x::StridedArray{<:Real},
86-
::Type{Val{:central}}, ::Type{Val{:Real}}, ::Type{Val{:Default}},
84+
# for R -> R^n
85+
function finite_difference!(df::StridedArray{<:Real}, f, x::Real,
86+
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:Default}},
8787
fx::Union{Void,StridedArray{<:Real}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, return_type::DataType=eltype(x))
8888

8989
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
90-
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
91-
@inbounds for i in 1 : length(x)
92-
epsilon = compute_epsilon(Val{:central}, x[i], epsilon_factor)
93-
epsilon_double_inv = one(typeof(epsilon)) / (2*epsilon)
94-
x_plus, x_minus = x[i]+epsilon, x[i]-epsilon
95-
df[i] = (f(x_plus) - f(x_minus)) * epsilon_double_inv
90+
if fdtype == Val{:forward}
91+
epsilon = compute_epsilon(Val{:forward}, x)
92+
if typeof(fx) == Void
93+
df .= (f(x+epsilon) - f(x)) / epsilon
94+
else
95+
df .= (f(x+epsilon) - fx) / epsilon
96+
end
97+
elseif fdtype == Val{:central}
98+
epsilon = compute_epsilon(Val{:central}, x)
99+
df .= (f(x+epsilon) - f(x-epsilon)) / (2*epsilon)
100+
elseif fdtype == Val{:complex}
101+
epsilon = eps(eltype(x))
102+
df .= imag(f(x+im*epsilon)) / epsilon
103+
else
104+
fdtype_error(Val{:Real})
96105
end
97106
df
98107
end
99108

109+
# for R^n -> R^n
100110
function finite_difference!(df::StridedArray{<:Real}, f, x::StridedArray{<:Real},
101-
::Type{Val{:forward}}, ::Type{Val{:Real}}, ::Type{Val{:Default}},
111+
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:Default}},
102112
fx::Union{Void,StridedArray{<:Real}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, return_type::DataType=eltype(x))
103113

104114
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
105-
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
106-
@inbounds for i in 1 : length(x)
107-
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)
108-
x_plus = x[i] + epsilon
109-
if typeof(fx) == Void
110-
df[i] = (f(x_plus) - f(x[i])) / epsilon
111-
else
112-
df[i] = (f(x_plus) - fx[i]) / epsilon
115+
if fdtype == Val{:forward}
116+
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
117+
@inbounds for i in 1 : length(x)
118+
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)
119+
x_plus = x[i] + epsilon
120+
if typeof(fx) == Void
121+
df[i] = (f(x_plus) - f(x[i])) / epsilon
122+
else
123+
df[i] = (f(x_plus) - fx[i]) / epsilon
124+
end
125+
end
126+
elseif fdtype == Val{:central}
127+
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
128+
@inbounds for i in 1 : length(x)
129+
epsilon = compute_epsilon(Val{:central}, x[i], epsilon_factor)
130+
epsilon_double_inv = one(typeof(epsilon)) / (2*epsilon)
131+
x_plus, x_minus = x[i]+epsilon, x[i]-epsilon
132+
df[i] = (f(x_plus) - f(x_minus)) * epsilon_double_inv
133+
end
134+
elseif fdtype == Val{:complex}
135+
epsilon_complex = eps(eltype(x))
136+
@inbounds for i in 1 : length(x)
137+
df[i] = imag(f(x[i]+im*epsilon_complex)) / epsilon_complex
113138
end
139+
else
140+
fdtype_error(Val{:Real})
114141
end
115142
df
116143
end
117144

118-
function finite_difference!(df::StridedArray{<:Real}, f, x::StridedArray{<:Real},
119-
::Type{Val{:complex}}, ::Type{Val{:Real}}, ::Type{Val{:Default}},
120-
fx::Union{Void,StridedArray{<:Real}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, return_type::DataType=eltype(x))
145+
# C -> C^n
146+
function finite_difference!(df::StridedArray{<:Number}, f, x::Number,
147+
fdtype::DataType, ::Type{Val{:Complex}}, ::Type{Val{:Default}},
148+
fx::Union{Void,StridedArray{<:Number}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, return_type::DataType=eltype(x))
121149

122-
epsilon_complex = eps(eltype(x))
123-
@inbounds for i in 1 : length(x)
124-
df[i] = imag(f(x[i]+im*epsilon_complex)) / epsilon_complex
150+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
151+
if fdtype == Val{:forward}
152+
epsilon = compute_epsilon(Val{:forward}, real(x[i]))
153+
if typeof(fx) == Void
154+
df .= ( real( f(x+epsilon) - f(x) ) + im*imag( f(x+im*epsilon) - f(x) ) ) / epsilon
155+
else
156+
df .= ( real( f(x+epsilon) - fx ) + im*imag( f(x+im*epsilon) - fx )) / epsilon
157+
end
158+
elseif fdtype == Val{:central}
159+
epsilon = compute_epsilon(Val{:central}, real(x[i]))
160+
df .= (real(f(x+epsilon) - f(x-epsilon)) + im*imag(f(x+im*epsilon) - f(x-im*epsilon))) / (2 * epsilon)
161+
else
162+
fdtype_error(Val{:Complex})
125163
end
126164
df
127165
end
128166

167+
# C^n -> C^n
129168
function finite_difference!(df::StridedArray{<:Number}, f, x::StridedArray{<:Number},
130169
fdtype::DataType, ::Type{Val{:Complex}}, ::Type{Val{:Default}},
131170
fx::Union{Void,StridedArray{<:Number}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, return_type::DataType=eltype(x))
@@ -147,8 +186,8 @@ function finite_difference!(df::StridedArray{<:Number}, f, x::StridedArray{<:Num
147186
epsilon = compute_epsilon(Val{:central}, real(x[i]), epsilon_factor)
148187
df[i] = (real(f(x[i]+epsilon) - f(x[i]-epsilon)) + im*imag(f(x[i]+im*epsilon) - f(x[i]-im*epsilon))) / (2 * epsilon)
149188
end
150-
elseif fdtype == Val{:complex}
151-
error("Invalid fdtype value, Val{:complex} not implemented for complex-valued functions.")
189+
else
190+
fdtype_error(Val{:Complex})
152191
end
153192
df
154193
end
@@ -169,6 +208,8 @@ function finite_difference(f, x::T, fdtype::DataType, funtype::DataType=Val{:Rea
169208
elseif funtype == Val{:Complex}
170209
epsilon = compute_epsilon(fdtype, real(x))
171210
return finite_difference_kernel(f, x, fdtype, funtype, epsilon, f_x)
211+
else
212+
fdtype_error(funtype)
172213
end
173214
end
174215

@@ -186,7 +227,7 @@ end
186227

187228
@inline function finite_difference_kernel(f, x::Number, ::Type{Val{:forward}}, ::Type{Val{:Complex}}, epsilon::Real, fx::Union{Void,<:Number}=nothing)
188229
if typeof(fx) == Void
189-
return real((f(x[i]+epsilon) - f(x[i]))) / epsilon + im*imag((f(x[i]+im*epsilon) - fx[i])) / epsilon
230+
return real((f(x[i]+epsilon) - f(x[i]))) / epsilon + im*imag((f(x[i]+im*epsilon) - f(x[i]))) / epsilon
190231
else
191232
return real((f(x[i]+epsilon) - fx[i])) / epsilon + im*imag((f(x[i]+im*epsilon) - fx[i])) / epsilon
192233
end

src/diffeqwrappers.jl

Lines changed: 73 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,90 @@
1-
function finite_difference!(df::AbstractArray{<:Real}, f, x::AbstractArray{<:Real},
2-
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:DiffEqDerivativeWrapper}},
3-
fx::Union{Void,AbstractArray{<:Real}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
1+
function finite_difference!(df::AbstractArray{<:Number}, f, x::Union{Number,AbstractArray{<:Number}},
2+
fdtype::DataType, funtype::DataType, ::Type{Val{:DiffEqDerivativeWrapper}},
3+
fx::Union{Void,AbstractArray{<:Number}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
44

5-
# TODO: test this one, and figure out what happens with epsilon
6-
fx1 = f.fx1
7-
if fdtype == Val{:forward}
8-
epsilon = compute_epsilon(Val{:forward}, x)
9-
f(fx, x)
10-
f(fx1, x+epsilon)
11-
@. df = (fx1 - fx) / epsilon
12-
elseif fdtype == Val{:central}
13-
epsilon = compute_epsilon(Val{:central}, x)
14-
f(fx, x-epsilon)
15-
f(fx1, x+epsilon)
16-
@. df = (fx1 - fx) / (2 * epsilon)
17-
elseif fdtype == Val{:complex}
18-
epsilon = eps(eltype(x))
19-
f(fx, f(x+im*epsilon))
20-
@. df = imag(fx) / epsilon
21-
end
5+
# TODO: optimized implementations for specific wrappers using the added DiffEq caching where appopriate
6+
7+
finite_difference!(df, f, x, fdtype, funtype, Val{:Default}, fx, epsilon, return_type)
228
df
239
end
2410

25-
# AbstractArray{T} should be OK if JacobianWrapper is provided
26-
function finite_difference_jacobian!(J::AbstractArray{T}, f, x::StridedArray{T}, ::Type{Val{:forward}}, fx::StridedArray{T}, ::Type{Val{:JacobianWrapper}}) where T<:Real
11+
function finite_difference_jacobian!(J::AbstractMatrix{<:Real}, f, x::AbstractArray{<:Real},
12+
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:JacobianWrapper}},
13+
fx::AbstractArray{<:Real}, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
14+
2715
m, n = size(J)
28-
epsilon_factor = compute_epsilon_factor(Val{:forward}, T)
16+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
2917
x1, fx1 = f.x1, f.fx1
3018
copy!(x1, x)
31-
copy!(fx1, fx)
32-
@inbounds for i in 1:n
33-
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)
34-
epsilon_inv = one(T) / epsilon
35-
x1[i] += epsilon
36-
f(fx, x)
37-
f(fx1, x1)
38-
@. J[:,i] = (fx-fx1) * epsilon_inv
39-
x1[i] -= epsilon
19+
if fdtype == Val{:forward}
20+
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
21+
@inbounds for i 1:n
22+
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)
23+
x1[i] += epsilon
24+
f(fx1, x1)
25+
f(fx, x)
26+
@. J[:,i] = (fx1 - fx) / epsilon
27+
x1[i] -= epsilon
28+
end
29+
elseif fdtype == Val{:central}
30+
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
31+
@inbounds for i 1:n
32+
epsilon = compute_epsilon(Val{:central}, x[i], epsilon_factor)
33+
x1[i] += epsilon
34+
x[i] -= epsilon
35+
f(fx1, x1)
36+
f(fx, x)
37+
@. J[:,i] = (fx1 - fx) / (2*epsilon)
38+
x1[i] -= epsilon
39+
x[i] += epsilon
40+
end
41+
elseif fdtype == Val{:complex}
42+
x0 = Complex{eltype(x)}(x)
43+
epsilon = eps(eltype(x))
44+
@inbounds for i 1:n
45+
x0[i] += im * epsilon
46+
@. J[:,i] = imag(f(x0)) / epsilon
47+
x0[i] -= im * epsilon
48+
end
49+
else
50+
fdtype_error(Val{:Real})
4051
end
4152
J
4253
end
4354

44-
function finite_difference_jacobian!(J::AbstractArray{T}, f, x::StridedArray{T}, ::Type{Val{:central}}, fx::StridedArray{T}, ::Type{Val{:JacobianWrapper}}) where T<:Real
55+
function finite_difference_jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
56+
fdtype::DataType, ::Type{Val{:Complex}}, ::Type{Val{:JacobianWrapper}},
57+
fx::AbstractArray{<:Number}, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, return_type::DataType=eltype(x))
58+
59+
# TODO: test this
4560
m, n = size(J)
46-
epsilon_factor = compute_epsilon_factor(Val{:central}, T)
61+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
4762
x1, fx1 = f.x1, f.fx1
4863
copy!(x1, x)
49-
copy!(fx1, fx)
50-
@inbounds for i in 1:n
51-
epsilon = compute_epsilon(Val{:central}, x[i], epsilon_factor)
52-
epsilon_double_inv = one(T) / (2 * epsilon)
53-
x[i] += epsilon
54-
x1[i] -= epsilon
55-
f(fx, x)
56-
f(fx1, x1)
57-
@. J[:,i] = (fx-fx1) * epsilon_double_inv
58-
x[i] -= epsilon
59-
x1[i] += epsilon
64+
if fdtype == Val{:forward}
65+
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
66+
@inbounds for i 1:n
67+
epsilon = compute_epsilon(Val{:forward}, real(x[i]), epsilon_factor)
68+
x1[i] += epsilon
69+
f(fx1, x1)
70+
f(fx, x)
71+
@. J[:,i] = ( real( (fx1 - fx) ) + im*imag( (fx1 - fx) ) ) / epsilon
72+
x1[i] -= epsilon
73+
end
74+
elseif fdtype == Val{:central}
75+
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
76+
@inbounds for i 1:n
77+
epsilon = compute_epsilon(Val{:central}, real(x[i]), epsilon_factor)
78+
x1[i] += epsilon
79+
x[i] -= epsilon
80+
f(fx1, x1)
81+
f(fx, x)
82+
@. J[:,i] = ( real( (fx1 - fx) ) + im*imag( fx1 - fx ) ) / (2*epsilon)
83+
x1[i] -= epsilon
84+
x[i] += epsilon
85+
end
86+
else
87+
fdtype_error(Val{:Complex})
6088
end
6189
J
6290
end

src/finitediff.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ end
2222
else
2323
error("Unrecognized fdtype: must be Val{:forward} or Val{:central}.")
2424
end
25+
nothing
2526
end
2627

2728
function compute_epsilon_elemtype(epsilon, x)
@@ -34,6 +35,18 @@ function compute_epsilon_elemtype(epsilon, x)
3435
else
3536
error("Could not compute epsilon type.")
3637
end
38+
nothing
39+
end
40+
41+
function fdtype_error(funtype::DataType=Val{:Real})
42+
if funtype == Val{:Real}
43+
error("Unrecognized fdtype: valid values are Val{:forward}, Val{:central} and Val{:complex}.")
44+
elseif funtype == Val{:Complex}
45+
error("Unrecognized fdtype: valid values are Val{:forward} or Val{:central}.")
46+
else
47+
error("Unrecognized funtype: valid values are Val{:Real} or Val{:Complex}.")
48+
end
49+
nothing
3750
end
3851

3952
include("derivatives.jl")

src/jacobians.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,15 @@ function finite_difference_jacobian!(J::AbstractMatrix{<:Number}, f, x::Abstract
1515
end
1616

1717
function finite_difference_jacobian!(J::AbstractMatrix{<:Real}, f, x::AbstractArray{<:Real},
18-
fdtype::DataType, ::Type{Val{:Real}}, wrappertype::DataType=Val{:Default},
18+
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:Default}},
1919
fx::Union{Void,AbstractArray{<:Real}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, returntype=eltype(x))
2020

2121
# TODO: test and rework this
2222
m, n = size(J)
2323
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
2424
if fdtype == Val{:forward}
2525
if typeof(fx) == Void
26-
if wrappertype==Val{:Default}
27-
fx = f.(x)
28-
elseif wrappertype==Val{:DiffEqJacobianWrapper}
29-
fx = f(x)
30-
else
31-
error("Unrecognized wrappertype: must be Val{:Default} or Val{:DiffEqJacobianWrapper}.")
32-
end
26+
fx = f.(x)
3327
end
3428
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
3529
shifted_x = copy(x)

0 commit comments

Comments
 (0)