Skip to content

Commit ca50299

Browse files
inplace functions
1 parent 7d51ca3 commit ca50299

File tree

2 files changed

+124
-10
lines changed

2 files changed

+124
-10
lines changed

src/jacobians.jl

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,123 @@ function _finite_difference_jacobian!(J::AbstractMatrix{<:Real}, f,
2727
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:Default}},
2828
fx, epsilon, returntype, inplace::Type{Val{true}})
2929

30+
# TODO: test and rework this to support GPUArrays and non-indexable types, if possible
31+
m, n = size(J)
32+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
33+
if fdtype == Val{:forward}
34+
if typeof(fx) == Void
35+
fx = similar(x,returntype)
36+
end
37+
f(fx,x)
38+
# TODO: Remove these allocations
39+
fx2 = similar(x,returntype)
40+
shifted_x = copy(x)
41+
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
42+
@inbounds for i in 1:n
43+
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)
44+
shifted_x[i] += epsilon
45+
f(fx2,shifted_x)
46+
J[:, i] .= (fx2 - fx) / epsilon
47+
shifted_x[i] = x[i]
48+
end
49+
elseif fdtype == Val{:central}
50+
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
51+
if typeof(fx) == Void
52+
fx1 = similar(x,returntype)
53+
else
54+
fx1 = fx
55+
end
56+
# TODO: Remove these allocations
57+
fx2 = similar(x,returntype)
58+
shifted_x_plus = copy(x)
59+
shifted_x_minus = copy(x)
60+
@inbounds for i in 1:n
61+
epsilon = compute_epsilon(Val{:central}, x[i], epsilon_factor)
62+
shifted_x_plus[i] += epsilon
63+
shifted_x_minus[i] -= epsilon
64+
f(fx1,shifted_x_plus)
65+
f(fx2,shifted_x_minus)
66+
J[:, i] .= (fx1 - fx2) / (epsilon + epsilon)
67+
shifted_x_plus[i] = x[i]
68+
shifted_x_minus[i] = x[i]
69+
end
70+
elseif fdtype == Val{:complex}
71+
x0 = Vector{Complex{eltype(x)}}(x)
72+
epsilon = eps(eltype(x))
73+
fx1 = similar(x,Complex{eltype(x)})
74+
@inbounds for i in 1:n
75+
x0[i] += im * epsilon
76+
f(fx1,x0)
77+
J[:,i] .= imag.(fx1) / epsilon
78+
x0[i] -= im * epsilon
79+
end
80+
else
81+
fdtype_error(Val{:Real})
82+
end
83+
J
84+
end
85+
86+
function _finite_difference_jacobian!(J::AbstractMatrix{<:Number}, f,
87+
x::AbstractArray{<:Number},
88+
fdtype::DataType, ::Type{Val{:Complex}}, ::Type{Val{:Default}},
89+
fx, epsilon, returntype, inplace::Type{Val{true}})
90+
91+
# TODO: test and rework this to support GPUArrays and non-indexable types, if possible
92+
m, n = size(J)
93+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
94+
if fdtype == Val{:forward}
95+
96+
if typeof(fx) == Void
97+
fx = similar(x,returntype)
98+
end
99+
f(fx,x)
100+
# TODO: Remove these allocations
101+
fx2 = similar(x,returntype)
102+
shifted_x = copy(x)
103+
104+
epsilon_factor = compute_epsilon_factor(Val{:forward}, epsilon_elemtype)
105+
106+
@inbounds for i in 1:n
107+
epsilon = compute_epsilon(Val{:forward}, real(x[i]), epsilon_factor)
108+
shifted_x[i] += epsilon
109+
f(fx2,shifted_x)
110+
@. J[:, i] = ( real(fx2 - fx ) + im*imag( fx2 - fx ) ) / epsilon
111+
shifted_x[i] = x[i]
112+
end
113+
elseif fdtype == Val{:central}
114+
epsilon_factor = compute_epsilon_factor(Val{:central}, epsilon_elemtype)
115+
116+
if typeof(fx) == Void
117+
fx1 = similar(x,returntype)
118+
else
119+
fx1 = fx
120+
end
121+
# TODO: Remove these allocations
122+
fx2 = similar(x,returntype)
123+
shifted_x_plus = copy(x)
124+
shifted_x_minus = copy(x)
125+
126+
@inbounds for i in 1:n
127+
epsilon = compute_epsilon(Val{:central}, real(x[i]), epsilon_factor)
128+
shifted_x_plus[i] += epsilon
129+
shifted_x_minus[i] -= epsilon
130+
f(fx1,shifted_x_plus)
131+
f(fx2,shifted_x_minus)
132+
@. J[:, i] = ( real(fx1 - fx2) + im*imag(fx1 - fx2) ) / (2 * epsilon)
133+
shifted_x_plus[i] = x[i]
134+
shifted_x_minus[i] = x[i]
135+
end
136+
else
137+
fdtype_error(Val{:Complex})
138+
end
139+
J
140+
end
141+
142+
function _finite_difference_jacobian!(J::AbstractMatrix{<:Real}, f,
143+
x::AbstractArray{<:Real},
144+
fdtype::DataType, ::Type{Val{:Real}}, ::Type{Val{:Default}},
145+
fx, epsilon, returntype, inplace::Type{Val{false}})
146+
30147
# TODO: test and rework this to support GPUArrays and non-indexable types, if possible
31148
m, n = size(J)
32149
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
@@ -71,7 +188,7 @@ end
71188
function _finite_difference_jacobian!(J::AbstractMatrix{<:Number}, f,
72189
x::AbstractArray{<:Number},
73190
fdtype::DataType, ::Type{Val{:Complex}}, ::Type{Val{:Default}},
74-
fx, epsilon, returntype, inplace::Type{Val{true}})
191+
fx, epsilon, returntype, inplace::Type{Val{false}})
75192

76193
# TODO: test and rework this to support GPUArrays and non-indexable types, if possible
77194
m, n = size(J)

test/finitedifftests.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,12 @@ df_ref = -sin.(real(x)) + im*cos.(imag(x))
6767
@test err_func(DiffEqDiffTools.finite_difference!(df, f, x, Val{:central}, Val{:Complex}, Val{:Default}, y, epsilon), df_ref) < 1e-8
6868
end
6969

70-
function f(x)
71-
fvec = zeros(x)
70+
function f(fvec,x)
7271
fvec[1] = (x[1]+3)*(x[2]^3-7)+18
7372
fvec[2] = sin(x[2]*exp(x[1])-1)
74-
fvec
7573
end
76-
x = rand(2)
77-
y = f(x)
74+
x = rand(2); y = rand(2)
75+
f(y,x)
7876
J_ref = [[-7+x[2]^3 3*(3+x[1])*x[2]^2]; [exp(x[1])*x[2]*cos(1-exp(x[1])*x[2]) exp(x[1])*cos(1-exp(x[1])*x[2])]]
7977
J = zeros(J_ref)
8078
df = zeros(x)
@@ -108,14 +106,13 @@ epsilon = zeros(x)
108106
@test err_func(DiffEqDiffTools.finite_difference_jacobian!(J, f, x, Val{:complex}, Val{:Real}, Val{:Default}, y, epsilon), J_ref) < 1e-14
109107
end
110108

111-
function f(x)
112-
fvec = zeros(x)
109+
function f(fvec,x)
113110
fvec[1] = (im*x[1]+3)*(x[2]^3-7)+18
114111
fvec[2] = sin(x[2]*exp(x[1])-1)
115-
fvec
116112
end
117113
x = rand(2) + im*rand(2)
118-
y = f(x)
114+
y = similar(x)
115+
f(y,x)
119116
J_ref = [[im*(-7+x[2]^3) 3*(3+im*x[1])*x[2]^2]; [exp(x[1])*x[2]*cos(1-exp(x[1])*x[2]) exp(x[1])*cos(1-exp(x[1])*x[2])]]
120117
J = zeros(J_ref)
121118
df = zeros(x)

0 commit comments

Comments
 (0)