@@ -42,7 +42,7 @@ function finite_difference!(df::AbstractArray{<:Real}, f, x::AbstractArray{<:Rea
42
42
epsilon_complex = eps (epsilon_elemtype)
43
43
@. df = imag (f (x+ im* epsilon_complex)) / epsilon_complex
44
44
else
45
- error ( " Unrecognized fdtype: valid values are Val{:forward}, Val{:central} and Val{:complex}. " )
45
+ fdtype_error ( Val{:Real } )
46
46
end
47
47
df
48
48
end
@@ -61,7 +61,6 @@ function finite_difference!(df::AbstractArray{<:Number}, f, x::AbstractArray{<:N
61
61
end
62
62
end
63
63
if fdtype == Val{:forward }
64
- @show typeof (x)
65
64
epsilon_factor = compute_epsilon_factor (Val{:forward }, eltype (epsilon))
66
65
@. epsilon = compute_epsilon (Val{:forward }, real (x), epsilon_factor)
67
66
if typeof (fx) == Void
@@ -72,8 +71,8 @@ function finite_difference!(df::AbstractArray{<:Number}, f, x::AbstractArray{<:N
72
71
epsilon_factor = compute_epsilon_factor (Val{:central }, eltype (epsilon))
73
72
@. epsilon = compute_epsilon (Val{:central }, real (x), epsilon_factor)
74
73
@. df = real (f (x+ epsilon) - f (x- epsilon)) / (2 * epsilon) + im* imag (f (x+ im* epsilon) - f (x- epsilon)) / (2 * epsilon)
75
- elseif fdtype == Val{ :complex }
76
- error ( " Invalid fdtype value, Val{:complex} not implemented for complex-valued functions. " )
74
+ else
75
+ fdtype_error ( Val{:Complex } )
77
76
end
78
77
df
79
78
end
82
81
#=
83
82
Optimized implementations for StridedArrays.
84
83
=#
85
- function finite_difference! (df:: StridedArray{<:Real} , f, x:: StridedArray{<:Real} ,
86
- :: Type{Val{:central}} , :: Type{Val{:Real}} , :: Type{Val{:Default}} ,
84
+ # for R -> R^n
85
+ function finite_difference! (df:: StridedArray{<:Real} , f, x:: Real ,
86
+ fdtype:: DataType , :: Type{Val{:Real}} , :: Type{Val{:Default}} ,
87
87
fx:: Union{Void,StridedArray{<:Real}} = nothing , epsilon:: Union{Void,StridedArray{<:Real}} = nothing , return_type:: DataType = eltype (x))
88
88
89
89
epsilon_elemtype = compute_epsilon_elemtype (epsilon, x)
90
- epsilon_factor = compute_epsilon_factor (Val{:central }, epsilon_elemtype)
91
- @inbounds for i in 1 : length (x)
92
- epsilon = compute_epsilon (Val{:central }, x[i], epsilon_factor)
93
- epsilon_double_inv = one (typeof (epsilon)) / (2 * epsilon)
94
- x_plus, x_minus = x[i]+ epsilon, x[i]- epsilon
95
- df[i] = (f (x_plus) - f (x_minus)) * epsilon_double_inv
90
+ if fdtype == Val{:forward }
91
+ epsilon = compute_epsilon (Val{:forward }, x)
92
+ if typeof (fx) == Void
93
+ df .= (f (x+ epsilon) - f (x)) / epsilon
94
+ else
95
+ df .= (f (x+ epsilon) - fx) / epsilon
96
+ end
97
+ elseif fdtype == Val{:central }
98
+ epsilon = compute_epsilon (Val{:central }, x)
99
+ df .= (f (x+ epsilon) - f (x- epsilon)) / (2 * epsilon)
100
+ elseif fdtype == Val{:complex }
101
+ epsilon = eps (eltype (x))
102
+ df .= imag (f (x+ im* epsilon)) / epsilon
103
+ else
104
+ fdtype_error (Val{:Real })
96
105
end
97
106
df
98
107
end
99
108
109
+ # for R^n -> R^n
100
110
function finite_difference! (df:: StridedArray{<:Real} , f, x:: StridedArray{<:Real} ,
101
- :: Type{Val{:forward}} , :: Type{Val{:Real}} , :: Type{Val{:Default}} ,
111
+ fdtype :: DataType , :: Type{Val{:Real}} , :: Type{Val{:Default}} ,
102
112
fx:: Union{Void,StridedArray{<:Real}} = nothing , epsilon:: Union{Void,StridedArray{<:Real}} = nothing , return_type:: DataType = eltype (x))
103
113
104
114
epsilon_elemtype = compute_epsilon_elemtype (epsilon, x)
105
- epsilon_factor = compute_epsilon_factor (Val{:forward }, epsilon_elemtype)
106
- @inbounds for i in 1 : length (x)
107
- epsilon = compute_epsilon (Val{:forward }, x[i], epsilon_factor)
108
- x_plus = x[i] + epsilon
109
- if typeof (fx) == Void
110
- df[i] = (f (x_plus) - f (x[i])) / epsilon
111
- else
112
- df[i] = (f (x_plus) - fx[i]) / epsilon
115
+ if fdtype == Val{:forward }
116
+ epsilon_factor = compute_epsilon_factor (Val{:forward }, epsilon_elemtype)
117
+ @inbounds for i in 1 : length (x)
118
+ epsilon = compute_epsilon (Val{:forward }, x[i], epsilon_factor)
119
+ x_plus = x[i] + epsilon
120
+ if typeof (fx) == Void
121
+ df[i] = (f (x_plus) - f (x[i])) / epsilon
122
+ else
123
+ df[i] = (f (x_plus) - fx[i]) / epsilon
124
+ end
125
+ end
126
+ elseif fdtype == Val{:central }
127
+ epsilon_factor = compute_epsilon_factor (Val{:central }, epsilon_elemtype)
128
+ @inbounds for i in 1 : length (x)
129
+ epsilon = compute_epsilon (Val{:central }, x[i], epsilon_factor)
130
+ epsilon_double_inv = one (typeof (epsilon)) / (2 * epsilon)
131
+ x_plus, x_minus = x[i]+ epsilon, x[i]- epsilon
132
+ df[i] = (f (x_plus) - f (x_minus)) * epsilon_double_inv
133
+ end
134
+ elseif fdtype == Val{:complex }
135
+ epsilon_complex = eps (eltype (x))
136
+ @inbounds for i in 1 : length (x)
137
+ df[i] = imag (f (x[i]+ im* epsilon_complex)) / epsilon_complex
113
138
end
139
+ else
140
+ fdtype_error (Val{:Real })
114
141
end
115
142
df
116
143
end
117
144
118
- function finite_difference! (df:: StridedArray{<:Real} , f, x:: StridedArray{<:Real} ,
119
- :: Type{Val{:complex}} , :: Type{Val{:Real}} , :: Type{Val{:Default}} ,
120
- fx:: Union{Void,StridedArray{<:Real}} = nothing , epsilon:: Union{Void,StridedArray{<:Real}} = nothing , return_type:: DataType = eltype (x))
145
+ # C -> C^n
146
+ function finite_difference! (df:: StridedArray{<:Number} , f, x:: Number ,
147
+ fdtype:: DataType , :: Type{Val{:Complex}} , :: Type{Val{:Default}} ,
148
+ fx:: Union{Void,StridedArray{<:Number}} = nothing , epsilon:: Union{Void,StridedArray{<:Real}} = nothing , return_type:: DataType = eltype (x))
121
149
122
- epsilon_complex = eps (eltype (x))
123
- @inbounds for i in 1 : length (x)
124
- df[i] = imag (f (x[i]+ im* epsilon_complex)) / epsilon_complex
150
+ epsilon_elemtype = compute_epsilon_elemtype (epsilon, x)
151
+ if fdtype == Val{:forward }
152
+ epsilon = compute_epsilon (Val{:forward }, real (x[i]))
153
+ if typeof (fx) == Void
154
+ df .= ( real ( f (x+ epsilon) - f (x) ) + im* imag ( f (x+ im* epsilon) - f (x) ) ) / epsilon
155
+ else
156
+ df .= ( real ( f (x+ epsilon) - fx ) + im* imag ( f (x+ im* epsilon) - fx )) / epsilon
157
+ end
158
+ elseif fdtype == Val{:central }
159
+ epsilon = compute_epsilon (Val{:central }, real (x[i]))
160
+ df .= (real (f (x+ epsilon) - f (x- epsilon)) + im* imag (f (x+ im* epsilon) - f (x- im* epsilon))) / (2 * epsilon)
161
+ else
162
+ fdtype_error (Val{:Complex })
125
163
end
126
164
df
127
165
end
128
166
167
+ # C^n -> C^n
129
168
function finite_difference! (df:: StridedArray{<:Number} , f, x:: StridedArray{<:Number} ,
130
169
fdtype:: DataType , :: Type{Val{:Complex}} , :: Type{Val{:Default}} ,
131
170
fx:: Union{Void,StridedArray{<:Number}} = nothing , epsilon:: Union{Void,StridedArray{<:Real}} = nothing , return_type:: DataType = eltype (x))
@@ -147,8 +186,8 @@ function finite_difference!(df::StridedArray{<:Number}, f, x::StridedArray{<:Num
147
186
epsilon = compute_epsilon (Val{:central }, real (x[i]), epsilon_factor)
148
187
df[i] = (real (f (x[i]+ epsilon) - f (x[i]- epsilon)) + im* imag (f (x[i]+ im* epsilon) - f (x[i]- im* epsilon))) / (2 * epsilon)
149
188
end
150
- elseif fdtype == Val{ :complex }
151
- error ( " Invalid fdtype value, Val{:complex} not implemented for complex-valued functions. " )
189
+ else
190
+ fdtype_error ( Val{:Complex } )
152
191
end
153
192
df
154
193
end
@@ -169,6 +208,8 @@ function finite_difference(f, x::T, fdtype::DataType, funtype::DataType=Val{:Rea
169
208
elseif funtype == Val{:Complex }
170
209
epsilon = compute_epsilon (fdtype, real (x))
171
210
return finite_difference_kernel (f, x, fdtype, funtype, epsilon, f_x)
211
+ else
212
+ fdtype_error (funtype)
172
213
end
173
214
end
174
215
186
227
187
228
@inline function finite_difference_kernel (f, x:: Number , :: Type{Val{:forward}} , :: Type{Val{:Complex}} , epsilon:: Real , fx:: Union{Void,<:Number} = nothing )
188
229
if typeof (fx) == Void
189
- return real ((f (x[i]+ epsilon) - f (x[i]))) / epsilon + im* imag ((f (x[i]+ im* epsilon) - fx [i])) / epsilon
230
+ return real ((f (x[i]+ epsilon) - f (x[i]))) / epsilon + im* imag ((f (x[i]+ im* epsilon) - f (x [i]) )) / epsilon
190
231
else
191
232
return real ((f (x[i]+ epsilon) - fx[i])) / epsilon + im* imag ((f (x[i]+ im* epsilon) - fx[i])) / epsilon
192
233
end
0 commit comments