Skip to content

Commit 00a0cf5

Browse files
dextoriousChrisRackauckas
authored andcommitted
Changed gradient convention to conjugate derivatives, added more tests.
1 parent 45728b2 commit 00a0cf5

File tree

3 files changed

+88
-60
lines changed

3 files changed

+88
-60
lines changed

src/derivatives.jl

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,51 +14,6 @@ function finite_difference_derivative(f, x::T, fdtype::Type{T1}=Val{:central},
1414
end
1515
fdtype_error(returntype)
1616
end
17-
18-
#=
19-
Finite difference kernels for single point derivatives.
20-
These are currently unused because of inlining / broadcast issues.
21-
Revisit this in Julia v0.7 / 1.0.
22-
=#
23-
#=
24-
@inline function _finite_difference_kernel(f, x::T, ::Type{Val{:forward}}, ::Type{Val{:Real}},
25-
epsilon::T, fx::Union{Void,T}=nothing) where T<:Real
26-
27-
if typeof(fx) == Void
28-
return (f(x+epsilon) - f(x)) / epsilon
29-
else
30-
return (f(x+epsilon) - fx) / epsilon
31-
end
32-
end
33-
34-
@inline function _finite_difference_kernel(f, x::T, ::Type{Val{:central}}, ::Type{Val{:Real}},
35-
epsilon::T, ::Union{Void,T}=nothing) where T<:Real
36-
37-
(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
38-
end
39-
40-
@inline function _finite_difference_kernel(f, x::T, ::Type{Val{:complex}}, ::Type{Val{:Real}},
41-
epsilon::T, ::Union{Void,T}=nothing) where T<:Real
42-
43-
imag(f(x+im*epsilon)) / epsilon
44-
end
45-
46-
@inline function _finite_difference_kernel(f, x::Number, ::Type{Val{:forward}}, ::Type{Val{:Complex}},
47-
epsilon::Real, fx::Union{Void,<:Number}=nothing)
48-
49-
if typeof(fx) == Void
50-
return real((f(x+epsilon) - f(x))) / epsilon + im*imag((f(x+im*epsilon) - f(x))) / epsilon
51-
else
52-
return real((f(x+epsilon) - fx)) / epsilon + im*imag((f(x+im*epsilon) - fx)) / epsilon
53-
end
54-
end
55-
56-
@inline function _finite_difference_kernel(f, x::Number, ::Type{Val{:central}}, ::Type{Val{:Complex}},
57-
epsilon::Real, fx::Union{Void,<:Number}=nothing)
58-
59-
real(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon) + im*imag(f(x+im*epsilon) - f(x-im*epsilon)) / (2 * epsilon)
60-
end
61-
=#
6217
# Single point derivative implementations end here.
6318

6419

src/gradients.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
169169
dfi = (f(c1) - fx0) / (im*epsilon)
170170
end
171171
c1[i] = c1_old
172-
df[i] += im * imag(dfi)
172+
df[i] -= im * imag(dfi)
173173
end
174174
end
175175
elseif fdtype == Val{:central}
@@ -185,7 +185,7 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
185185
if eltype(df)<:Complex
186186
c1[i] += im*epsilon
187187
x[i] -= im*epsilon
188-
df[i] += im*imag( (f(c1) - f(x)) / (2*im*epsilon) )
188+
df[i] -= im*imag( (f(c1) - f(x)) / (2*im*epsilon) )
189189
c1[i] = c1_old
190190
x[i] = x_old
191191
end
@@ -251,7 +251,7 @@ function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedV
251251
end
252252
c1[i] = x_old
253253
end
254-
df[i] += im * imag(dfi)
254+
df[i] -= im * imag(dfi)
255255
end
256256
end
257257
elseif fdtype == Val{:central}
@@ -278,7 +278,7 @@ function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedV
278278
dfi -= f(c1)
279279
c1[i] = x_old
280280
end
281-
df[i] += im*imag(dfi / (2*im*epsilon))
281+
df[i] -= im*imag(dfi / (2*im*epsilon))
282282
end
283283
end
284284
elseif fdtype==Val{:complex} && returntype<:Real && eltype(df)<:Real && eltype(x)<:Real

test/finitedifftests.jl

Lines changed: 84 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ complex_cache = DiffEqDiffTools.DerivativeCache(x, nothing, nothing, Val{:comple
1414

1515
err_func(a,b) = maximum(abs.(a-b))
1616

17-
@time @testset "Derivative single point real-valued tests" begin
17+
@time @testset "Derivative single point f : R -> R tests" begin
1818
@test err_func(DiffEqDiffTools.finite_difference_derivative(sin, π/4, Val{:forward}), 2/2) < 1e-4
1919
@test err_func(DiffEqDiffTools.finite_difference_derivative(sin, π/4, Val{:central}), 2/2) < 1e-8
2020
@test err_func(DiffEqDiffTools.finite_difference_derivative(sin, π/4, Val{:complex}), 2/2) < 1e-15
2121
end
2222

23-
@time @testset "Derivative StridedArray real-valued tests" begin
23+
@time @testset "Derivative StridedArray f : R -> R tests" begin
2424
@test err_func(DiffEqDiffTools.finite_difference_derivative(sin, x, Val{:forward}), df_ref) < 1e-4
2525
@test err_func(DiffEqDiffTools.finite_difference_derivative(sin, x, Val{:central}), df_ref) < 1e-8
2626
@test err_func(DiffEqDiffTools.finite_difference_derivative(sin, x, Val{:complex}), df_ref) < 1e-15
@@ -51,12 +51,12 @@ df_ref = cos.(x) - sin.(x)
5151
forward_cache = DiffEqDiffTools.DerivativeCache(x, y, epsilon, Val{:forward})
5252
central_cache = DiffEqDiffTools.DerivativeCache(x, y, epsilon, Val{:central})
5353

54-
@time @testset "Derivative single point complex-valued tests" begin
54+
@time @testset "Derivative single point f : C -> C tests" begin
5555
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, π/4+im*π/4, Val{:forward}, Val{:Complex}), cos/4+im*π/4)-sin/4+im*π/4)) < 1e-3
5656
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, π/4+im*π/4, Val{:central}, Val{:Complex}), cos/4+im*π/4)-sin/4+im*π/4)) < 1e-7
5757
end
5858

59-
@time @testset "Derivative StridedArray complex-valued tests" begin
59+
@time @testset "Derivative StridedArray f : C -> C tests" begin
6060
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, x, Val{:forward}), df_ref) < 1e-3
6161
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, x, Val{:central}), df_ref) < 1e-6
6262

@@ -79,6 +79,83 @@ end
7979
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, central_cache), df_ref) < 1e-6
8080
end
8181

82+
x = collect(linspace(-2π, 2π, 100))
83+
f(x) = sin(x) + im*cos(x)
84+
y = f.(x)
85+
df = zeros(Complex{eltype(x)}, size(x))
86+
epsilon = similar(real(x))
87+
df_ref = cos.(x) - im*sin.(x)
88+
forward_cache = DiffEqDiffTools.DerivativeCache(x, y, epsilon, Val{:forward}, eltype(df))
89+
central_cache = DiffEqDiffTools.DerivativeCache(x, y, epsilon, Val{:central}, eltype(df))
90+
91+
@time @testset "Derivative single point f : R -> C tests" begin
92+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, π/4, Val{:forward}, Val{:Complex}), cos/4)-im*sin/4)) < 1e-3
93+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, π/4, Val{:central}, Val{:Complex}), cos/4)-im*sin/4)) < 1e-7
94+
end
95+
96+
@time @testset "Derivative StridedArray f : R -> C tests" begin
97+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, x, Val{:forward}, Complex{eltype(x)}), df_ref) < 1e-3
98+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, x, Val{:central}, Complex{eltype(x)}), df_ref) < 1e-6
99+
100+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, x, Val{:forward}, Complex{eltype(x)}, y), df_ref) < 1e-3
101+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, x, Val{:central}, Complex{eltype(x)}, y), df_ref) < 1e-6
102+
103+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, x, Val{:forward}, Complex{eltype(x)}, y, epsilon), df_ref) < 1e-3
104+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, x, Val{:central}, Complex{eltype(x)}, y, epsilon), df_ref) < 1e-6
105+
106+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, Val{:forward}, Complex{eltype(x)}), df_ref) < 1e-3
107+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, Val{:central}, Complex{eltype(x)}), df_ref) < 1e-6
108+
109+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, Val{:forward}, Complex{eltype(x)}, y), df_ref) < 1e-3
110+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, Val{:central}, Complex{eltype(x)}, y), df_ref) < 1e-6
111+
112+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, Val{:forward}, Complex{eltype(x)}, y, epsilon), df_ref) < 1e-3
113+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, Val{:central}, Complex{eltype(x)}, y, epsilon), df_ref) < 1e-6
114+
115+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, forward_cache), df_ref) < 1e-3
116+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, central_cache), df_ref) < 1e-6
117+
end
118+
119+
#=
120+
x = x + im*x
121+
f(x) = abs2(x)
122+
y = f.(x)
123+
df = zeros(eltype(x), size(x))
124+
epsilon = similar(real(x))
125+
df_ref = 2*conj.(x)
126+
forward_cache = DiffEqDiffTools.DerivativeCache(x, y, epsilon, Val{:forward}, eltype(df))
127+
central_cache = DiffEqDiffTools.DerivativeCache(x, y, epsilon, Val{:central}, eltype(df))
128+
@show typeof(DiffEqDiffTools.finite_difference_derivative(f, 1.+im*1., Val{:forward}, real(eltype(x))))
129+
130+
@time @testset "Derivative single point f : C -> R tests" begin
131+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, 1.+im*1., Val{:forward}, real(eltype(x))), 2.-2.*im) < 1e-3
132+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, 1.+im*1., Val{:central}, real(eltype(x))), 2.-2.*im) < 1e-7
133+
end
134+
135+
@time @testset "Derivative StridedArray f : C -> R tests" begin
136+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, x, Val{:forward}, real(eltype(x))), df_ref) < 1e-3
137+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, x, Val{:central}, real(eltype(x))), df_ref) < 1e-6
138+
139+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, x, Val{:forward}, real(eltype(x)), y), df_ref) < 1e-3
140+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, x, Val{:central}, real(eltype(x)), y), df_ref) < 1e-6
141+
142+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, x, Val{:forward}, real(eltype(x)), y, epsilon), df_ref) < 1e-3
143+
@test err_func(DiffEqDiffTools.finite_difference_derivative(f, x, Val{:central}, real(eltype(x)), y, epsilon), df_ref) < 1e-6
144+
145+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, Val{:forward}, real(eltype(x))), df_ref) < 1e-3
146+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, Val{:central}, real(eltype(x))), df_ref) < 1e-6
147+
148+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, Val{:forward}, real(eltype(x)), y), df_ref) < 1e-3
149+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, Val{:central}, real(eltype(x)), y), df_ref) < 1e-6
150+
151+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, Val{:forward}, real(eltype(x)), y, epsilon), df_ref) < 1e-3
152+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, Val{:central}, real(eltype(x)), y, epsilon), df_ref) < 1e-6
153+
154+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, forward_cache), df_ref) < 1e-3
155+
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, central_cache), df_ref) < 1e-6
156+
end
157+
=#
158+
82159
# Gradient tests
83160
f(x) = 2x[1] + x[2]^2
84161
x = rand(2)
@@ -107,9 +184,7 @@ f(x) = 2x[1] + im*2x[1] + x[2]^2
107184
x = x + im*x
108185
fx = f(x)
109186
df = zeros(x)
110-
df_ref = [2.0+2.0*im, 2.0*x[2]]
111-
DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward})
112-
DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central})
187+
df_ref = conj([2.0+2.0*im, 2.0*x[2]])
113188
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward})
114189
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central})
115190

@@ -128,9 +203,7 @@ f(x) = sum(abs2, x)
128203
x = ones(2) * (1 + im)
129204
fx = f(x)
130205
df = zeros(x)
131-
df_ref = 2*conj.(x)
132-
DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward})
133-
DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central})
206+
df_ref = 2*x
134207
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward})
135208
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central})
136209

@@ -149,7 +222,7 @@ f(x) = 2*x[1] + im*x[2]^2
149222
x = ones(2)
150223
fx = f(x)
151224
df = zeros(eltype(fx), size(x))
152-
df_ref = [2.0, im*2*x[2]]
225+
df_ref = [2.0, -im*2*x[2]]
153226
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward},eltype(df))
154227
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central},eltype(df))
155228

0 commit comments

Comments
 (0)