Skip to content

Commit 0e9ad46

Browse files
committed
Jacobians of complex-valued callables should work, at least for StridedArrays.
1 parent e520985 commit 0e9ad46

File tree

2 files changed

+13
-26
lines changed

2 files changed

+13
-26
lines changed

src/finitediff.jl

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ end
221221
# TODO: optimized implementations for DiffEq wrappers
222222

223223
#=
224-
Compute the derivative df of a real-valued callable f on a collection of points x.
224+
Compute the derivative df of a callable f on a collection of points x.
225225
Single point implementations.
226226
=#
227227
function finite_difference(f, x::T, fdtype::DataType, funtype::DataType=Val{:Real}, f_x::Union{Void,T}=nothing) where T<:Number
@@ -428,51 +428,37 @@ function finite_difference_jacobian!(J::StridedMatrix{<:Number}, f, x::StridedAr
428428
m, n = size(J)
429429
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
430430
if fdtype == Val{:forward}
431-
epsilon_factor = compute_epsilon_factor(Val{:forward}, eltype(x))
431+
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
432432
@inbounds for i in 1:n
433-
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)
433+
epsilon = compute_epsilon(Val{:forward}, real(x[i]), epsilon_factor)
434434
epsilon_inv = one(returntype) / epsilon
435435
for j in 1:m
436436
if i==j
437437
if typeof(fx) == Void
438-
J[j,i] = (f(x[j]+epsilon) - f(x[j])) * epsilon_inv
438+
J[j,i] = ( real( f(x[j]+epsilon) - f(x[j]) ) + im*imag( f(x[j]+im*epsilon) - f(x[j]) ) ) * epsilon_inv
439439
else
440-
if typeof(fx) == Void
441-
J[j,i] = (f(x[j]+epsilon) - f(x[j])) * epsilon_inv
442-
else
443-
J[j,i] = (f(x[j]+epsilon) - fx[j]) * epsilon_inv
444-
end
440+
J[j,i] = ( real( f(x[j]+epsilon) - fx[j] ) + im*imag( f(x[j]+im*epsilon) - fx[j] ) ) * epsilon_inv
445441
end
446442
else
447443
J[j,i] = zero(returntype)
448444
end
449445
end
450446
end
451447
elseif fdtype == Val{:central}
452-
epsilon_factor = compute_epsilon_factor(Val{:central}, eltype(x))
448+
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
453449
@inbounds for i in 1:n
454-
epsilon = compute_epsilon(Val{:central}, x[i], epsilon_factor)
450+
epsilon = compute_epsilon(Val{:central}, real(x[i]), epsilon_factor)
455451
epsilon_double_inv = one(returntype) / (2 * epsilon)
456452
for j in 1:m
457453
if i==j
458-
J[j,i] = (f(x[j]+epsilon) - f(x[j]-epsilon)) * epsilon_double_inv
459-
else
460-
J[j,i] = zero(returntype)
461-
end
462-
end
463-
end
464-
elseif fdtype == Val{:complex}
465-
epsilon = eps(epsilon_elemtype)
466-
epsilon_inv = one(epsilon_elemtype) / epsilon
467-
@inbounds for i in 1:n
468-
for j in 1:m
469-
if i==j
470-
J[j,i] = imag(f(x[j]+im*epsilon)) * epsilon_inv
454+
J[j,i] = ( real( f(x[j]+epsilon)-f(x[j]-epsilon) ) + im*imag( f(x[j]+im*epsilon) - f(x[j]-im*epsilon) ) ) * epsilon_double_inv
471455
else
472456
J[j,i] = zero(returntype)
473457
end
474458
end
475459
end
460+
else
461+
error("Unrecognized fdtype: must be Val{:forward} or Val{:central}.")
476462
end
477463
J
478464
end

test/finitedifftests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ J_ref = diagm(df_ref)
8181
@test err_func(DiffEqDiffTools.finite_difference(f, x, Val{:central}, Val{:Complex}, Val{:Default}, y, epsilon), df_ref) < 1e-8
8282

8383
# Jacobian tests for complex-valued callables
84-
#@test err_func(DiffEqDiffTools.finite_difference_jacobian(f, x, Val{:forward}, Val{:Complex}, Val{:Default}), J_ref) < 1e-4
85-
#@test err_func(DiffEqDiffTools.finite_difference_jacobian(f, x, Val{:central}, Val{:Complex}, Val{:Default}), J_ref) < 1e-8
84+
@test err_func(DiffEqDiffTools.finite_difference_jacobian(f, x, Val{:forward}, Val{:Complex}, Val{:Default}), J_ref) < 1e-4
85+
@test err_func(DiffEqDiffTools.finite_difference_jacobian(f, x, Val{:central}, Val{:Complex}, Val{:Default}), J_ref) < 1e-8
86+
8687
# StridedArray tests end here

0 commit comments

Comments
 (0)