Skip to content

Commit fbe6dbc

Browse files
dextoriousChrisRackauckas
authored andcommitted
Gradient bugfixes.
1 parent 09101c7 commit fbe6dbc

File tree

4 files changed

+36
-17
lines changed

4 files changed

+36
-17
lines changed

src/DiffEqDiffTools.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ module DiffEqDiffTools
55
include("function_wrappers.jl")
66
include("finitediff.jl")
77
include("derivatives.jl")
8+
include("gradients.jl")
89
include("jacobians.jl")
910

1011
end # module

src/finitediff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ end
2020
elseif fdtype==Val{:central}
2121
return cbrt(eps(T))
2222
else
23-
error("Unrecognized fdtype: must be Val{:forward} or Val{:central}.")
23+
error("Unrecognized fdtype $fdtype: must be Val{:forward} or Val{:central}.")
2424
end
2525
end
2626

src/gradients.jl

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct GradientCache{CacheType, CacheType2, CacheType3, fdtype, RealOrComplex}
1+
struct GradientCache{CacheType1, CacheType2, CacheType3, fdtype, RealOrComplex}
22
fx :: CacheType1
33
c1 :: CacheType2
44
c2 :: CacheType3
@@ -7,19 +7,15 @@ end
77
function GradientCache(
88
df :: AbstractArray{<:Number},
99
x :: Union{<:Number, AbstractArray{<:Number}},
10-
fx :: Union{Void,AbstractArray{<:Number}} = nothing,
11-
c1 :: AbstractArray{<:Number} = nothing,
12-
c2 :: AbstractArray{<:Number} = nothing,
10+
fx :: Union{Void,<:Number,AbstractArray{<:Number}} = nothing,
11+
c1 :: Union{Void,AbstractArray{<:Number}} = nothing,
12+
c2 :: Union{Void,AbstractArray{<:Number}} = nothing,
1313
fdtype :: DataType = Val{:central},
1414
RealOrComplex :: DataType =
15-
fdtype==Val{:complex} ? Val{:Real} : eltype(x1) <: Complex ? Val{:Complex} : Val{:Real}
15+
fdtype==Val{:complex} ? Val{:Real} : eltype(x) <: Complex ? Val{:Complex} : Val{:Real}
1616
)
1717

18-
if fdtype == Val{:complex} && RealOrComplex == Val{:Complex}
19-
fdtype_error(Val{:Complex})
20-
end
21-
22-
if fdtype != Val{:forward}
18+
if fdtype != Val{:forward} && typeof(fx) != Void
2319
warn("Pre-computed function values are only useful for fdtype == Val{:forward}.")
2420
_fx = nothing
2521
else
@@ -28,8 +24,8 @@ function GradientCache(
2824
end
2925

3026
if typeof(x) <: AbstractArray # the f:R^n->R case
31-
# need cache arrays for epsilon and x1
32-
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
27+
# need cache arrays for epsilon (c1) and x1 (c2)
28+
epsilon_elemtype = compute_epsilon_elemtype(nothing, x)
3329
if typeof(c1) == Void || eltype(c1) != epsilon_elemtype
3430
_c1 = zeros(epsilon_elemtype, size(x))
3531
else
@@ -56,7 +52,7 @@ function GradientCache(
5652
_c2 = c2
5753
end
5854
end
59-
GradientCache{typeof(fx),typeof(c1),typeof(c2),fdtype,RealOrComplex}(fx,c1,c2)
55+
GradientCache{typeof(_fx),typeof(_c1),typeof(_c2),fdtype,RealOrComplex}(_fx,_c1,_c2)
6056
end
6157

6258
function finite_difference_gradient(f, x, fdtype::DataType=Val{:central},
@@ -107,20 +103,18 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
107103
# c2 is x1, pre-set to the values of x by the cache constructor
108104
fx, c1, c2 = cache.fx, cache.c1, cache.c2
109105
if fdtype == Val{:forward}
110-
# TODO: do we even need c2 for the forward case?
111106
@inbounds for i eachindex(x)
112107
c2[i] += c1[i]
113108
df[i] = (f(c2) - f(x)) / c1[i]
114109
c2[i] -= c1[i]
115110
end
116111
elseif fdtype == Val{:central}
117112
@inbounds for i eachindex(x)
118-
#copy!(x1,x)
119113
c2[i] += c1[i]
120114
x[i] -= c1[i]
121115
df[i] = (f(c2) - f(x)) / (2*c1[i])
122116
c2[i] -= c1[i]
123-
x[i] += c1[i] # revert any changes to x
117+
x[i] += c1[i]
124118
end
125119
elseif fdtype == Val{:complex}
126120
# TODO

test/finitedifftests.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,30 @@ end
8585
@test err_func(DiffEqDiffTools.finite_difference_derivative!(df, f, x, central_cache), df_ref) < 1e-8
8686
end
8787

88+
f(x) = 2x[1] + x[2]^2
89+
x = rand(2)
90+
fx = f(x)
91+
df = zeros(2)
92+
df_ref = [2., 2*x[2]]
93+
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward},Val{:Real})
94+
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central},Val{:Real})
95+
#complex_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:complex},Val{:Real})
96+
97+
@time @testset "Gradient of f:R^n->R real-valued tests" begin
98+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}), df_ref) < 1e-4
99+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central}), df_ref) < 1e-8
100+
#@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:complex}), df_ref) < 1e-15
101+
102+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:forward}), df_ref) < 1e-4
103+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:central}), df_ref) < 1e-8
104+
#@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:complex}), df_ref) < 1e-15
105+
106+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, forward_cache), df_ref) < 1e-4
107+
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, central_cache), df_ref) < 1e-8
108+
#@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, complex_cache), df_ref) < 1e-15
109+
end
110+
111+
88112
function f(fvec,x)
89113
fvec[1] = (x[1]+3)*(x[2]^3-7)+18
90114
fvec[2] = sin(x[2]*exp(x[1])-1)

0 commit comments

Comments
 (0)