Skip to content

Commit 44892a2

Browse files
allocating dispatches for the Jacobian
1 parent 00a0cf5 commit 44892a2

File tree

2 files changed

+80
-42
lines changed

2 files changed

+80
-42
lines changed

src/jacobians.jl

Lines changed: 75 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,56 +5,91 @@ struct JacobianCache{CacheType1,CacheType2,CacheType3,fdtype,returntype,inplace}
55
end
66

77
function JacobianCache(
8-
x :: AbstractArray{<:Number},
9-
x1 :: Union{Void,AbstractArray{<:Number}} = nothing,
10-
fx :: Union{Void,AbstractArray{<:Number}} = nothing,
11-
fx1 :: Union{Void,AbstractArray{<:Number}} = nothing,
8+
x,
129
fdtype :: Type{T1} = Val{:central},
1310
returntype :: Type{T2} = eltype(x),
1411
inplace :: Type{Val{T3}} = Val{true}) where {T1,T2,T3}
12+
if eltype(x) <: Real && fdtype==Val{:complex}
13+
x1 = zeros(Complex{eltype(x)}, size(x))
14+
_fx = zeros(Complex{eltype(x)}, size(x))
15+
else
16+
x1 = similar(x)
17+
_fx = similar(x)
18+
end
1519

1620
if fdtype==Val{:complex}
17-
if !(returntype<:Real)
18-
fdtype_error(returntype)
19-
end
20-
if eltype(fx)!=Complex{eltype(x)}
21-
_fx = zeros(Complex{eltype(x)}, size(x))
22-
end
23-
if typeof(fx1)!=Void
24-
warn("fdtype==Val{:complex} doesn't benefit from caching fx1.")
25-
end
26-
_fx1 = nothing
27-
if eltype(x1) != Complex{eltype(x)}
28-
_x1 = zeros(Complex{eltype(x)}, size(x))
29-
else
30-
_x1 = x1
31-
end
21+
_fx1 = nothing
3222
else
33-
if eltype(x1) != eltype(x)
34-
_x1 = similar(x)
35-
else
36-
_x1 = x1
37-
end
38-
if eltype(fx) != returntype
39-
_fx = zeros(returntype, size(x))
23+
_fx1 = similar(x)
24+
end
25+
26+
JacobianCache(x1,_fx,_fx1,fdtype,returntype,inplace)
27+
end
28+
29+
function JacobianCache(
30+
x ,
31+
fx,
32+
fdtype :: Type{T1} = Val{:central},
33+
returntype :: Type{T2} = eltype(x),
34+
inplace :: Type{Val{T3}} = Val{true}) where {T1,T2,T3}
35+
36+
if eltype(x) <: Real && fdtype==Val{:complex}
37+
x1 = zeros(Complex{eltype(x)}, size(x))
38+
else
39+
x1 = similar(x)
40+
end
41+
42+
if eltype(fx) <: Real && fdtype==Val{:complex}
43+
_fx = zeros(Complex{eltype(x)}, size(fx))
44+
else
45+
_fx = similar(fx)
46+
end
47+
48+
if fdtype==Val{:complex}
49+
_fx1 = nothing
50+
else
51+
_fx1 = similar(fx)
52+
end
53+
54+
JacobianCache(x1,_fx,_fx1,fdtype,returntype,inplace)
55+
end
56+
57+
function JacobianCache(
58+
x1 ,
59+
fx ,
60+
fx1,
61+
fdtype :: Type{T1} = Val{:central},
62+
returntype :: Type{T2} = eltype(x),
63+
inplace :: Type{Val{T3}} = Val{true}) where {T1,T2,T3}
64+
65+
if fdtype==Val{:complex}
66+
!(returntype<:Real) && fdtype_error(returntype)
67+
68+
if eltype(fx) <: Real
69+
_fx = zeros(Complex{eltype(x)}, size(fx))
4070
else
4171
_fx = fx
4272
end
43-
if eltype(fx1) != returntype
44-
_fx1 = zeros(returntype, size(x))
73+
if eltype(x1) <: Real
74+
_x1 = zeros(Complex{eltype(x)}, size(x))
4575
else
46-
_fx1 = fx1
76+
_x1 = x1
4777
end
78+
else
79+
_x1 = x1
80+
@assert eltype(fx) == T2
81+
@assert eltype(fx1) == T2
82+
_fx = fx
4883
end
49-
JacobianCache{typeof(_x1),typeof(_fx),typeof(_fx1),fdtype,returntype,inplace}(_x1,_fx,_fx1)
84+
JacobianCache{typeof(_x1),typeof(_fx),typeof(fx1),fdtype,returntype,inplace}(_x1,_fx,fx1)
5085
end
5186

5287
function finite_difference_jacobian(f, x::AbstractArray{<:Number},
5388
fdtype :: Type{T1}=Val{:central},
5489
returntype :: Type{T2}=eltype(x),
5590
inplace :: Type{Val{T3}}=Val{true}) where {T1,T2,T3}
5691

57-
cache = JacobianCache(x,nothing,nothing,nothing,fdtype,returntype,inplace)
92+
cache = JacobianCache(x,fdtype,returntype,inplace)
5893
finite_difference_jacobian(f,x,cache)
5994
end
6095

@@ -81,11 +116,12 @@ function finite_difference_jacobian!(J::AbstractMatrix{<:Number}, f,x::AbstractA
81116
if inplace == Val{true}
82117
f(fx1, x1)
83118
f(fx, x)
119+
J[:,i] = (vfx1 - vfx) / epsilon
84120
else
85121
fx1 .= f(x1)
86122
fx .= f(x)
123+
J[:,i] = (vfx1 - vfx) / epsilon
87124
end
88-
@. J[:,i] = (vfx1 - vfx) / epsilon
89125
x1[i] = x1_save
90126
end
91127
elseif fdtype == Val{:central}
@@ -100,11 +136,12 @@ function finite_difference_jacobian!(J::AbstractMatrix{<:Number}, f,x::AbstractA
100136
if inplace == Val{true}
101137
f(fx1, x1)
102138
f(fx, x)
139+
@. J[:,i] = (vfx1 - vfx) / (2*epsilon)
103140
else
104-
fx1 .= f(x1)
105-
fx .= f(x)
141+
fx1 = f(x1)
142+
fx = f(x)
143+
J[:,i] = (vfx1 - vfx) / (2*epsilon)
106144
end
107-
@. J[:,i] = (vfx1 - vfx) / (2*epsilon)
108145
x1[i] = x1_save
109146
x[i] = x_save
110147
end
@@ -115,10 +152,11 @@ function finite_difference_jacobian!(J::AbstractMatrix{<:Number}, f,x::AbstractA
115152
x1[i] += im * epsilon
116153
if inplace == Val{true}
117154
f(fx,x1)
155+
@. J[:,i] = imag(vfx) / epsilon
118156
else
119-
fx .= f(x1)
157+
fx = f(x1)
158+
J[:,i] = imag(vfx) / epsilon
120159
end
121-
@. J[:,i] = imag(vfx) / epsilon
122160
x1[i] = x1_save
123161
end
124162
else

test/finitedifftests.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,9 @@ J = zeros(J_ref)
294294
df = zeros(x)
295295
df_ref = diag(J_ref)
296296
epsilon = zeros(x)
297-
forward_cache = DiffEqDiffTools.JacobianCache(x,similar(x),similar(x),similar(x),Val{:forward})
298-
central_cache = DiffEqDiffTools.JacobianCache(x,similar(x),similar(x),similar(x))
299-
complex_cache = DiffEqDiffTools.JacobianCache(x,nothing,nothing,nothing,Val{:complex})
297+
forward_cache = DiffEqDiffTools.JacobianCache(x,Val{:forward})
298+
central_cache = DiffEqDiffTools.JacobianCache(x)
299+
complex_cache = DiffEqDiffTools.JacobianCache(x,Val{:complex})
300300

301301
@time @testset "Jacobian StridedArray real-valued tests" begin
302302
@test err_func(DiffEqDiffTools.finite_difference_jacobian(f, x, forward_cache), J_ref) < 1e-4
@@ -317,8 +317,8 @@ J = zeros(J_ref)
317317
df = zeros(x)
318318
df_ref = diag(J_ref)
319319
epsilon = zeros(real.(x))
320-
forward_cache = DiffEqDiffTools.JacobianCache(x,similar(x),similar(x),similar(x),Val{:forward})
321-
central_cache = DiffEqDiffTools.JacobianCache(x,similar(x),similar(x),similar(x))
320+
forward_cache = DiffEqDiffTools.JacobianCache(x,Val{:forward})
321+
central_cache = DiffEqDiffTools.JacobianCache(x)
322322

323323
@time @testset "Jacobian StridedArray f : C^N -> C^N tests" begin
324324
@test err_func(DiffEqDiffTools.finite_difference_jacobian(f, x, forward_cache), J_ref) < 1e-4

0 commit comments

Comments
 (0)