@@ -5,11 +5,11 @@ function finite_difference_derivative(f, x::T, fdtype::DataType=Val{:central},
5
5
returntype:: DataType = eltype (x), f_x:: Union{Void,T} = nothing ) where T<: Number
6
6
7
7
epsilon = compute_epsilon (fdtype, real (x))
8
- if fdtype == Val{:forward }
8
+ if fdtype== Val{:forward }
9
9
return (f (x+ epsilon) - f (x)) / epsilon
10
- elseif fdtype == Val{:central }
10
+ elseif fdtype== Val{:central }
11
11
return (f (x+ epsilon) - f (x- epsilon)) / (2 * epsilon)
12
- elseif fdtype == Val{:complex } && returntype == Val{ :Real }
12
+ elseif fdtype== Val{:complex } && returntype< :Real
13
13
return imag (f (x+ im* epsilon)) / epsilon
14
14
end
15
15
fdtype_error (returntype)
65
65
#=
66
66
Multi-point implementations of scalar derivatives for efficiency.
67
67
=#
68
- struct DerivativeCache{CacheType1, CacheType2, fdtype, RealOrComplex }
68
+ struct DerivativeCache{CacheType1, CacheType2, fdtype, returntype }
69
69
fx :: CacheType1
70
70
epsilon :: CacheType2
71
71
end
72
72
73
73
function DerivativeCache (
74
- x :: AbstractArray{<:Number} ,
75
- fx :: Union{Void,AbstractArray{<:Number}} = nothing ,
76
- epsilon :: Union{Void,AbstractArray{<:Number}} = nothing ,
77
- fdtype :: DataType = Val{:central },
78
- RealOrComplex :: DataType =
79
- fdtype== Val{:complex } ? Val{:Real } : eltype (x) <: Complex ?
80
- Val{:Complex } : Val{:Real }
81
- )
82
-
83
- if fdtype == Val{:complex } && RealOrComplex == Val{:Complex }
84
- fdtype_error (Val{:Complex })
74
+ x :: AbstractArray{<:Number} ,
75
+ fx :: Union{Void,AbstractArray{<:Number}} = nothing ,
76
+ epsilon :: Union{Void,AbstractArray{<:Number}} = nothing ,
77
+ fdtype :: DataType = Val{:central },
78
+ returntype :: DataType = eltype (x))
79
+
80
+ if fdtype== Val{:complex } && ! (eltype (returntype)<: Real )
81
+ fdtype_error (returntype)
85
82
end
86
83
87
- if fdtype != Val{:forward }
88
- warn (" Pre-computed function values are only useful for fdtype == Val{:forward}." )
84
+ if fdtype!= Val{:forward } && typeof (fx) != Void
85
+ warn (" Pre-computed function values are only useful for fdtype== Val{:forward}." )
89
86
_fx = nothing
90
87
else
91
88
# more runtime sanity checks?
92
89
_fx = fx
93
90
end
94
91
95
- if typeof (epsilon) == Void
92
+ if typeof (epsilon)!= Void && typeof (x)<: StridedArray && typeof (fx)<: Union{Void,StridedArray} && 1 == 2
93
+ warn (" StridedArrays don't benefit from pre-allocating epsilon." )
94
+ _epsilon = nothing
95
+ elseif typeof (epsilon)!= Void && fdtype== Val{:complex }
96
+ warn (" Val{:complex} makes the epsilon array redundant." )
96
97
_epsilon = nothing
97
98
else
98
- if typeof (x)<: StridedArray && typeof (fx)<: Union{Void,StridedArray}
99
- warn (" StridedArrays don't benefit from pre-allocating epsilon." )
100
- _epsilon = nothing
101
- elseif fdtype == Val{:complex }
102
- warn (" Val{:complex} makes the epsilon array redundant." )
103
- _epsilon = nothing
104
- else
105
- epsilon_elemtype = compute_epsilon_elemtype (epsilon, x)
106
- if typeof (epsilon) == Void || eltype (epsilon) != epsilon_elemtype
107
- epsilon = zeros (epsilon_elemtype, size (x))
108
- end
109
- epsilon_factor = compute_epsilon_factor (fdtype, real (eltype (x)))
110
- @. epsilon = compute_epsilon (fdtype, real (x), epsilon_factor)
111
- _epsilon = epsilon
99
+ if typeof (epsilon)== Void || eltype (epsilon)!= real (eltype (x))
100
+ epsilon = zeros (real (eltype (x)), size (x))
112
101
end
102
+ epsilon_factor = compute_epsilon_factor (fdtype, real (eltype (x)))
103
+ @. epsilon = compute_epsilon (fdtype, real (x), epsilon_factor)
104
+ _epsilon = epsilon
113
105
end
114
- DerivativeCache {typeof(_fx),typeof(_epsilon),fdtype,RealOrComplex } (_fx,_epsilon)
106
+ DerivativeCache {typeof(_fx),typeof(_epsilon),fdtype,returntype } (_fx,_epsilon)
115
107
end
116
108
117
109
#=
118
110
Compute the derivative df of a scalar-valued map f at a collection of points x.
119
111
=#
120
- function finite_difference_derivative (f, x :: AbstractArray{<:Number} , fdtype :: DataType = Val{ :central },
121
- RealOrComplex :: DataType =
122
- fdtype == Val{ :complex } ? Val{ :Real } : eltype (x) <: Complex ?
123
- Val{ :Complex } : Val{:Real },
124
- fx :: Union{Void,AbstractArray{<:Number}} = nothing ,
125
- epsilon :: Union{Void,AbstractArray{<:Real}} = nothing ,
126
- return_type :: DataType = eltype (x) )
127
-
128
- df = zeros (return_type , size (x))
129
- finite_difference_derivative! (df, f, x, fdtype, RealOrComplex , fx, epsilon, return_type )
112
+ function finite_difference_derivative (
113
+ f,
114
+ x :: AbstractArray{<:Number} ,
115
+ fdtype :: DataType = Val{:central },
116
+ returntype :: DataType = eltype (x), # return type of f
117
+ fx :: Union{Void,AbstractArray{<:Number}} = nothing ,
118
+ epsilon :: Union{Void,AbstractArray{<:Real}} = nothing )
119
+
120
+ df = zeros (returntype , size (x))
121
+ finite_difference_derivative! (df, f, x, fdtype, returntype , fx, epsilon)
130
122
end
131
123
132
- function finite_difference_derivative! (df:: AbstractArray{<:Number} , f,
133
- x:: AbstractArray{<:Number} , fdtype:: DataType = Val{:central },
134
- RealOrComplex :: DataType =
135
- fdtype== Val{:complex } ? Val{:Real } : eltype (x) <: Complex ?
136
- Val{:Complex } : Val{:Real },
137
- fx:: Union{Void,AbstractArray{<:Number}} = nothing ,
138
- epsilon:: Union{Void,AbstractArray{<:Real}} = nothing , return_type:: DataType = eltype (x))
139
-
140
- cache = DerivativeCache (x, fx, epsilon, fdtype, RealOrComplex)
141
- _finite_difference_derivative! (df, f, x, cache)
124
+ function finite_difference_derivative! (
125
+ df :: AbstractArray{<:Number} ,
126
+ f,
127
+ x :: AbstractArray{<:Number} ,
128
+ fdtype :: DataType = Val{:central },
129
+ returntype :: DataType = eltype (x),
130
+ fx :: Union{Void,AbstractArray{<:Number}} = nothing ,
131
+ epsilon :: Union{Void,AbstractArray{<:Real}} = nothing )
132
+
133
+ cache = DerivativeCache (x, fx, epsilon, fdtype, returntype)
134
+ finite_difference_derivative! (df, f, x, cache)
142
135
end
143
136
144
137
function finite_difference_derivative! (df:: AbstractArray{<:Number} , f, x:: AbstractArray{<:Number} ,
145
- cache:: DerivativeCache{T1,T2,fdtype,RealOrComplex} ) where {T1,T2,fdtype,RealOrComplex}
146
-
147
- _finite_difference_derivative! (df, f, x, cache)
148
- end
149
-
150
- function _finite_difference_derivative! (df:: AbstractArray{<:Real} , f, x:: AbstractArray{<:Real} ,
151
- cache:: DerivativeCache{T1,T2,fdtype,Val{:Real}} ) where {T1,T2,fdtype}
138
+ cache:: DerivativeCache{T1,T2,fdtype,returntype} ) where {T1,T2,fdtype,returntype}
152
139
153
140
fx, epsilon = cache. fx, cache. epsilon
154
141
if fdtype == Val{:forward }
@@ -159,29 +146,11 @@ function _finite_difference_derivative!(df::AbstractArray{<:Real}, f, x::Abstrac
159
146
end
160
147
elseif fdtype == Val{:central }
161
148
@. df = (f (x+ epsilon) - f (x- epsilon)) / (2 * epsilon)
162
- elseif fdtype == Val{:complex }
163
- epsilon_elemtype = compute_epsilon_elemtype (nothing , x)
164
- epsilon_complex = eps (epsilon_elemtype)
149
+ elseif fdtype == Val{:complex } && returntype<: Real
150
+ epsilon_complex = eps (eltype (x))
165
151
@. df = imag (f (x+ im* epsilon_complex)) / epsilon_complex
166
152
else
167
- fdtype_error (Val{:Real })
168
- end
169
- df
170
- end
171
-
172
- function _finite_difference_derivative! (df:: AbstractArray{<:Number} , f, x:: AbstractArray{<:Number} ,
173
- cache:: DerivativeCache{T1,T2,fdtype,Val{:Complex}} ) where {T1,T2,fdtype}
174
-
175
- fx, epsilon = cache. fx, cache. epsilon
176
- if fdtype == Val{:forward }
177
- if typeof (fx) == Void
178
- fx = f .(x)
179
- end
180
- @. df = real ((f (x+ epsilon) - fx)) / epsilon + im* imag ((f (x+ epsilon) - fx)) / epsilon
181
- elseif fdtype == Val{:central }
182
- @. df = real (f (x+ epsilon) - f (x- epsilon)) / (2 * epsilon) + im* imag (f (x+ epsilon) - f (x- epsilon)) / (2 * epsilon)
183
- else
184
- fdtype_error (Val{:Complex })
153
+ fdtype_error (returntype)
185
154
end
186
155
df
187
156
end
@@ -191,14 +160,13 @@ Optimized implementations for StridedArrays.
191
160
Essentially, the only difference between these and the AbstractArray case
192
161
is that here we can compute the epsilon one by one in local variables and avoid caching it.
193
162
=#
194
- function _finite_difference_derivative! (df:: StridedArray{<:Real} , f, x:: StridedArray{<:Real} ,
195
- cache:: DerivativeCache{T1,T2,fdtype,Val{:Real}} ) where {T1,T2,fdtype}
163
+ function _finite_difference_derivative! (df:: StridedArray , f, x:: StridedArray ,
164
+ cache:: DerivativeCache{T1,T2,fdtype,returntype} ) where {T1,T2,fdtype,returntype }
196
165
197
- epsilon_elemtype = compute_epsilon_elemtype ( nothing , x )
166
+ epsilon_factor = compute_epsilon_factor (fdtype, real ( eltype (x)) )
198
167
if fdtype == Val{:forward }
199
168
fx = cache. fx
200
- epsilon_factor = compute_epsilon_factor (Val{:forward }, epsilon_elemtype)
201
- @inbounds for i in 1 : length (x)
169
+ @inbounds for i ∈ eachindex (x)
202
170
epsilon = compute_epsilon (Val{:forward }, x[i], epsilon_factor)
203
171
x_plus = x[i] + epsilon
204
172
if typeof (fx) == Void
@@ -208,48 +176,19 @@ function _finite_difference_derivative!(df::StridedArray{<:Real}, f, x::StridedA
208
176
end
209
177
end
210
178
elseif fdtype == Val{:central }
211
- epsilon_factor = compute_epsilon_factor (Val{:central }, epsilon_elemtype)
212
- @inbounds for i in 1 : length (x)
179
+ @inbounds for i ∈ eachindex (x)
213
180
epsilon = compute_epsilon (Val{:central }, x[i], epsilon_factor)
214
181
epsilon_double_inv = one (typeof (epsilon)) / (2 * epsilon)
215
182
x_plus, x_minus = x[i]+ epsilon, x[i]- epsilon
216
183
df[i] = (f (x_plus) - f (x_minus)) * epsilon_double_inv
217
184
end
218
185
elseif fdtype == Val{:complex }
219
186
epsilon_complex = eps (eltype (x))
220
- @inbounds for i in 1 : length (x)
187
+ @inbounds for i ∈ eachindex (x)
221
188
df[i] = imag (f (x[i]+ im* epsilon_complex)) / epsilon_complex
222
189
end
223
190
else
224
- fdtype_error (Val{:Real })
225
- end
226
- df
227
- end
228
-
229
- function _finite_difference_derivative! (df:: StridedArray{<:Number} , f, x:: StridedArray{<:Number} ,
230
- cache:: DerivativeCache{T1,T2,fdtype,Val{:Complex}} ) where {T1,T2,fdtype}
231
-
232
- epsilon_elemtype = compute_epsilon_elemtype (nothing , x)
233
- if fdtype == Val{:forward }
234
- fx = cache. fx
235
- epsilon_factor = compute_epsilon_factor (Val{:forward }, epsilon_elemtype)
236
- @inbounds for i in 1 : length (x)
237
- epsilon = compute_epsilon (Val{:forward }, real (x[i]), epsilon_factor)
238
- if typeof (fx) == Void
239
- fxi = f (x[i])
240
- else
241
- fxi = fx[i]
242
- end
243
- df[i] = ( real ( f (x[i]+ epsilon) - fxi ) + im* imag ( f (x[i]+ epsilon) - fxi ) ) / epsilon
244
- end
245
- elseif fdtype == Val{:central }
246
- epsilon_factor = compute_epsilon_factor (Val{:central }, epsilon_elemtype)
247
- @inbounds for i in 1 : length (x)
248
- epsilon = compute_epsilon (Val{:central }, real (x[i]), epsilon_factor)
249
- df[i] = ( real ( f (x[i]+ epsilon) - f (x[i]- epsilon) ) + im* imag ( f (x[i]+ epsilon) - f (x[i]- epsilon) ) ) / (2 * epsilon)
250
- end
251
- else
252
- fdtype_error (Val{:Complex })
191
+ fdtype_error (returntype)
253
192
end
254
193
df
255
194
end
0 commit comments