1
- struct GradientCache{CacheType , CacheType2, CacheType3, fdtype, RealOrComplex}
1
+ struct GradientCache{CacheType1 , CacheType2, CacheType3, fdtype, RealOrComplex}
2
2
fx :: CacheType1
3
3
c1 :: CacheType2
4
4
c2 :: CacheType3
7
7
function GradientCache (
8
8
df :: AbstractArray{<:Number} ,
9
9
x :: Union{<:Number, AbstractArray{<:Number}} ,
10
- fx :: Union{Void,AbstractArray{<:Number}} = nothing ,
11
- c1 :: AbstractArray{<:Number} = nothing ,
12
- c2 :: AbstractArray{<:Number} = nothing ,
10
+ fx :: Union{Void,<:Number, AbstractArray{<:Number}} = nothing ,
11
+ c1 :: Union{Void, AbstractArray{<:Number} } = nothing ,
12
+ c2 :: Union{Void, AbstractArray{<:Number} } = nothing ,
13
13
fdtype :: DataType = Val{:central },
14
14
RealOrComplex :: DataType =
15
- fdtype== Val{:complex } ? Val{:Real } : eltype (x1 ) <: Complex ? Val{:Complex } : Val{:Real }
15
+ fdtype== Val{:complex } ? Val{:Real } : eltype (x ) <: Complex ? Val{:Complex } : Val{:Real }
16
16
)
17
17
18
- if fdtype == Val{:complex } && RealOrComplex == Val{:Complex }
19
- fdtype_error (Val{:Complex })
20
- end
21
-
22
- if fdtype != Val{:forward }
18
+ if fdtype != Val{:forward } && typeof (fx) != Void
23
19
warn (" Pre-computed function values are only useful for fdtype == Val{:forward}." )
24
20
_fx = nothing
25
21
else
@@ -28,8 +24,8 @@ function GradientCache(
28
24
end
29
25
30
26
if typeof (x) <: AbstractArray # the f:R^n->R case
31
- # need cache arrays for epsilon and x1
32
- epsilon_elemtype = compute_epsilon_elemtype (epsilon , x)
27
+ # need cache arrays for epsilon (c1) and x1 (c2)
28
+ epsilon_elemtype = compute_epsilon_elemtype (nothing , x)
33
29
if typeof (c1) == Void || eltype (c1) != epsilon_elemtype
34
30
_c1 = zeros (epsilon_elemtype, size (x))
35
31
else
@@ -56,7 +52,7 @@ function GradientCache(
56
52
_c2 = c2
57
53
end
58
54
end
59
- GradientCache {typeof(fx ),typeof(c1 ),typeof(c2 ),fdtype,RealOrComplex} (fx,c1,c2 )
55
+ GradientCache {typeof(_fx ),typeof(_c1 ),typeof(_c2 ),fdtype,RealOrComplex} (_fx,_c1,_c2 )
60
56
end
61
57
62
58
function finite_difference_gradient (f, x, fdtype:: DataType = Val{:central },
@@ -107,20 +103,18 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
107
103
# c2 is x1, pre-set to the values of x by the cache constructor
108
104
fx, c1, c2 = cache. fx, cache. c1, cache. c2
109
105
if fdtype == Val{:forward }
110
- # TODO : do we even need c2 for the forward case?
111
106
@inbounds for i ∈ eachindex (x)
112
107
c2[i] += c1[i]
113
108
df[i] = (f (c2) - f (x)) / c1[i]
114
109
c2[i] -= c1[i]
115
110
end
116
111
elseif fdtype == Val{:central }
117
112
@inbounds for i ∈ eachindex (x)
118
- # copy!(x1,x)
119
113
c2[i] += c1[i]
120
114
x[i] -= c1[i]
121
115
df[i] = (f (c2) - f (x)) / (2 * c1[i])
122
116
c2[i] -= c1[i]
123
- x[i] += c1[i] # revert any changes to x
117
+ x[i] += c1[i]
124
118
end
125
119
elseif fdtype == Val{:complex }
126
120
# TODO
0 commit comments