Skip to content

Commit f788176

Browse files
committed
Add the option to get the derivatives from the Jacobians without extra computation.
1 parent 4803b91 commit f788176

File tree

2 files changed

+75
-7
lines changed

2 files changed

+75
-7
lines changed

src/jacobians.jl

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,32 @@
11
# Compute the Jacobian matrix of a real-valued callable f.
22
function finite_difference_jacobian(f, x::AbstractArray{<:Number},
33
fdtype::DataType=Val{:central}, funtype::DataType=Val{:Real}, wrappertype::DataType=Val{:Default},
4-
fx::Union{Void,AbstractArray{<:Number}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, returntype=eltype(x))
4+
fx::Union{Void,AbstractArray{<:Number}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, returntype=eltype(x),
5+
df::Union{Void,AbstractArray{<:Number}}=nothing)
56

67
J = zeros(returntype, length(x), length(x))
7-
finite_difference_jacobian!(J, f, x, fdtype, funtype, wrappertype, fx, epsilon, returntype)
8+
finite_difference_jacobian!(J, f, x, fdtype, funtype, wrappertype, fx, epsilon, returntype, df)
89
end
910

10-
function finite_difference_jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
11+
function finite_difference_jacobian!(J::AbstractMatrix{<:Number}, df::AbstractVector, f, x::AbstractArray{<:Number},
1112
fdtype::DataType=Val{:central}, funtype::DataType=Val{:Real}, wrappertype::DataType=Val{:Default},
1213
fx::Union{Void,AbstractArray{<:Number}}=nothing, epsilon::Union{Void,AbstractArray{<:Number}}=nothing, returntype=eltype(x))
1314

14-
finite_difference_jacobian!(J, f, x, fdtype, funtype, wrappertype, fx, epsilon, returntype)
15+
finite_difference_jacobian!(J, f, x, fdtype, funtype, wrappertype, fx, epsilon, returntype, df)
16+
end
17+
18+
function finite_difference_jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
19+
fdtype::DataType=Val{:central}, funtype::DataType=Val{:Real}, wrappertype::DataType=Val{:Default},
20+
fx::Union{Void,AbstractArray{<:Number}}=nothing, epsilon::Union{Void,AbstractArray{<:Number}}=nothing, returntype=eltype(x),
21+
df::Union{Void,AbstractArray{<:Number}}=nothing)
22+
23+
finite_difference_jacobian!(J, f, x, fdtype, funtype, wrappertype, fx, epsilon, returntype, df)
1524
end
1625

1726
function finite_difference_jacobian!(J::AbstractMatrix{<:Real}, f, x::AbstractArray{<:Real},
1827
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:Default}},
19-
fx::Union{Void,AbstractArray{<:Real}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, returntype=eltype(x))
28+
fx::Union{Void,AbstractArray{<:Real}}=nothing, epsilon::Union{Void,AbstractArray{<:Real}}=nothing, returntype=eltype(x),
29+
df::Union{Void,AbstractArray{<:Real}}=nothing)
2030

2131
# TODO: test and rework this
2232
m, n = size(J)
@@ -48,12 +58,16 @@ function finite_difference_jacobian!(J::AbstractMatrix{<:Real}, f, x::AbstractAr
4858
else
4959
error("Unrecognized fdtype: must be Val{:forward} or Val{:central}.")
5060
end
61+
if typeof(df) != Void
62+
df .= diag(J)
63+
end
5164
J
5265
end
5366

5467
function finite_difference_jacobian!(J::StridedMatrix{<:Real}, f, x::StridedArray{<:Real},
5568
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:Default}},
56-
fx::Union{Void,StridedArray{<:Real}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, returntype=eltype(x))
69+
fx::Union{Void,StridedArray{<:Real}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, returntype=eltype(x),
70+
df::Union{Void,StridedArray{<:Real}}=nothing)
5771

5872
m, n = size(J)
5973
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
@@ -73,6 +87,9 @@ function finite_difference_jacobian!(J::StridedMatrix{<:Real}, f, x::StridedArra
7387
J[j,i] = (f(x[j]+epsilon) - fx[j]) * epsilon_inv
7488
end
7589
end
90+
if typeof(df) != Void
91+
df[i] = J[j,i]
92+
end
7693
else
7794
J[j,i] = zero(returntype)
7895
end
@@ -86,6 +103,9 @@ function finite_difference_jacobian!(J::StridedMatrix{<:Real}, f, x::StridedArra
86103
for j in 1:m
87104
if i==j
88105
J[j,i] = (f(x[j]+epsilon) - f(x[j]-epsilon)) * epsilon_double_inv
106+
if typeof(df) != Void
107+
df[i] = J[j,i]
108+
end
89109
else
90110
J[j,i] = zero(returntype)
91111
end
@@ -98,6 +118,9 @@ function finite_difference_jacobian!(J::StridedMatrix{<:Real}, f, x::StridedArra
98118
for j in 1:m
99119
if i==j
100120
J[j,i] = imag(f(x[j]+im*epsilon)) * epsilon_inv
121+
if typeof(df) != Void
122+
df[i] = J[j,i]
123+
end
101124
else
102125
J[j,i] = zero(returntype)
103126
end
@@ -109,7 +132,8 @@ end
109132

110133
function finite_difference_jacobian!(J::StridedMatrix{<:Number}, f, x::StridedArray{<:Number},
111134
fdtype::DataType, ::Type{Val{:Complex}}, ::Type{Val{:Default}},
112-
fx::Union{Void,StridedArray{<:Number}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, returntype=eltype(x))
135+
fx::Union{Void,StridedArray{<:Number}}=nothing, epsilon::Union{Void,StridedArray{<:Real}}=nothing, returntype=eltype(x),
136+
df::Union{Void,StridedArray{<:Number}}=nothing)
113137

114138
# TODO: finish this
115139
m, n = size(J)
@@ -126,6 +150,9 @@ function finite_difference_jacobian!(J::StridedMatrix{<:Number}, f, x::StridedAr
126150
else
127151
J[j,i] = ( real( f(x[j]+epsilon) - fx[j] ) + im*imag( f(x[j]+im*epsilon) - fx[j] ) ) * epsilon_inv
128152
end
153+
if typeof(df) != Void
154+
df[i] = J[j,i]
155+
end
129156
else
130157
J[j,i] = zero(returntype)
131158
end
@@ -139,6 +166,9 @@ function finite_difference_jacobian!(J::StridedMatrix{<:Number}, f, x::StridedAr
139166
for j in 1:m
140167
if i==j
141168
J[j,i] = ( real( f(x[j]+epsilon)-f(x[j]-epsilon) ) + im*imag( f(x[j]+im*epsilon) - f(x[j]-im*epsilon) ) ) * epsilon_double_inv
169+
if typeof(df) != Void
170+
df[i] = J[j,i]
171+
end
142172
else
143173
J[j,i] = zero(returntype)
144174
end

test/finitedifftests.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,28 @@ end
6666
@test err_func(DiffEqDiffTools.finite_difference_jacobian!(J, sin, x, Val{:complex}, Val{:Real}, Val{:Default}, y, epsilon), J_ref) < 1e-15
6767
end
6868

69+
# Jacobian tests w/ derivatives (real-valued callables)
70+
@time @testset "Jacobian StridedArray real-valued derivative tests" begin
71+
DiffEqDiffTools.finite_difference_jacobian!(J, df, sin, x, Val{:forward})
72+
@test err_func(df, df_ref) < 1e-4
73+
DiffEqDiffTools.finite_difference_jacobian!(J, df, sin, x, Val{:central})
74+
@test err_func(df, df_ref) < 1e-8
75+
DiffEqDiffTools.finite_difference_jacobian!(J, df, sin, x, Val{:complex})
76+
@test err_func(df, df_ref) < 1e-15
77+
DiffEqDiffTools.finite_difference_jacobian!(J, df, sin, x, Val{:forward}, Val{:Real}, Val{:Default}, y)
78+
@test err_func(df, df_ref) < 1e-4
79+
DiffEqDiffTools.finite_difference_jacobian!(J, df, sin, x, Val{:central}, Val{:Real}, Val{:Default}, y)
80+
@test err_func(df, df_ref) < 1e-8
81+
DiffEqDiffTools.finite_difference_jacobian!(J, df, sin, x, Val{:complex}, Val{:Real}, Val{:Default}, y)
82+
@test err_func(df, df_ref) < 1e-15
83+
DiffEqDiffTools.finite_difference_jacobian!(J, df, sin, x, Val{:forward}, Val{:Real}, Val{:Default}, y, epsilon)
84+
@test err_func(df, df_ref) < 1e-4
85+
DiffEqDiffTools.finite_difference_jacobian!(J, df, sin, x, Val{:central}, Val{:Real}, Val{:Default}, y, epsilon)
86+
@test err_func(df, df_ref) < 1e-8
87+
DiffEqDiffTools.finite_difference_jacobian!(J, df, sin, x, Val{:complex}, Val{:Real}, Val{:Default}, y, epsilon)
88+
@test err_func(df, df_ref) < 1e-15
89+
end
90+
6991
# derivative tests for complex-valued callables
7092
x = x + im*x
7193
f(x) = cos(real(x)) + im*sin(imag(x))
@@ -116,4 +138,20 @@ end
116138
@test err_func(DiffEqDiffTools.finite_difference_jacobian!(J, f, x, Val{:forward}, Val{:Complex}, Val{:Default}, y, epsilon), J_ref) < 1e-4
117139
@test err_func(DiffEqDiffTools.finite_difference_jacobian!(J, f, x, Val{:central}, Val{:Complex}, Val{:Default}, y, epsilon), J_ref) < 1e-8
118140
end
141+
142+
# Jacobian tests w/ derivatives (real-valued callables)
143+
@time @testset "Jacobian StridedArray complex-valued derivative tests" begin
144+
DiffEqDiffTools.finite_difference_jacobian!(J, df, f, x, Val{:forward}, Val{:Complex}, Val{:Default})
145+
@test err_func(df, df_ref) < 1e-4
146+
DiffEqDiffTools.finite_difference_jacobian!(J, df, f, x, Val{:central}, Val{:Complex}, Val{:Default})
147+
@test err_func(df, df_ref) < 1e-8
148+
DiffEqDiffTools.finite_difference_jacobian!(J, df, f, x, Val{:forward}, Val{:Complex}, Val{:Default}, y)
149+
@test err_func(df, df_ref) < 1e-4
150+
DiffEqDiffTools.finite_difference_jacobian!(J, df, f, x, Val{:central}, Val{:Complex}, Val{:Default}, y)
151+
@test err_func(df, df_ref) < 1e-8
152+
DiffEqDiffTools.finite_difference_jacobian!(J, df, f, x, Val{:forward}, Val{:Complex}, Val{:Default}, y, epsilon)
153+
@test err_func(df, df_ref) < 1e-4
154+
DiffEqDiffTools.finite_difference_jacobian!(J, df, f, x, Val{:central}, Val{:Complex}, Val{:Default}, y, epsilon)
155+
@test err_func(df, df_ref) < 1e-8
156+
end
119157
# StridedArray tests end here

0 commit comments

Comments
 (0)