Skip to content

Commit 09101c7

Browse files
dextoriousChrisRackauckas
authored andcommitted
Gradient prototype implementation, Val{:central} with Val{:Real} should work.
1 parent d1f565b commit 09101c7

File tree

2 files changed

+181
-50
lines changed

2 files changed

+181
-50
lines changed

src/derivatives.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -82,38 +82,38 @@ function DerivativeCache(
8282
Val{:Complex} : Val{:Real}
8383
)
8484

85-
if typeof(x)<:StridedArray && typeof(fx)<:Union{Void,StridedArray}
86-
if typeof(epsilon)!=Void
87-
warn("StridedArrays don't benefit from pre-allocating epsilon.")
88-
epsilon = nothing
89-
end
85+
if fdtype == Val{:complex} && RealOrComplex == Val{:Complex}
86+
fdtype_error(Val{:Complex})
9087
end
91-
if fdtype == Val{:complex}
92-
if RealOrComplex == Val{:Complex}
93-
fdtype_error(Val{:Complex})
94-
end
95-
if typeof(fx) != Void || typeof(epsilon) != Void
96-
warn("Val{:complex} doesn't benefit from cache arrays.")
97-
end
98-
return DerivativeCache{Void,Void,fdtype,RealOrComplex}(nothing, nothing)
88+
89+
if fdtype != Val{:forward}
90+
warn("Pre-computed function values are only useful for fdtype == Val{:forward}.")
91+
_fx = nothing
92+
else
93+
# more runtime sanity checks?
94+
_fx = fx
95+
end
96+
97+
if typeof(epsilon) == Void
98+
_epsilon = nothing
9999
else
100-
if !(typeof(x)<:StridedArray && typeof(fx)<:Union{Void,StridedArray})
100+
if typeof(x)<:StridedArray && typeof(fx)<:Union{Void,StridedArray}
101+
warn("StridedArrays don't benefit from pre-allocating epsilon.")
102+
_epsilon = nothing
103+
elseif fdtype == Val{:complex}
104+
warn("Val{:complex} makes the epsilon array redundant.")
105+
_epsilon = nothing
106+
else
101107
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
102108
if typeof(epsilon) == Void || eltype(epsilon) != epsilon_elemtype
103109
epsilon = zeros(epsilon_elemtype, size(x))
104110
end
105111
epsilon_factor = compute_epsilon_factor(fdtype, real(eltype(x)))
106112
@. epsilon = compute_epsilon(fdtype, real(x), epsilon_factor)
107-
end
108-
if fdtype != Val{:forward}
109-
if typeof(fx) != Void
110-
warn("Pre-computed function values are only useful for fdtype == Val{:forward}.")
111-
end
112-
return DerivativeCache{Void,typeof(epsilon),fdtype,RealOrComplex}(nothing,epsilon)
113-
else
114-
return DerivativeCache{typeof(fx),typeof(epsilon),fdtype,RealOrComplex}(fx,epsilon)
113+
_epsilon = epsilon
115114
end
116115
end
116+
DerivativeCache{typeof(_fx),typeof(_epsilon),fdtype,RealOrComplex}(_fx,_epsilon)
117117
end
118118

119119
#=

src/gradients.jl

Lines changed: 159 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,167 @@
1-
struct GradientCache{CacheType,CacheType2,CacheType3,fdtype,RealOrComplex}
2-
x1::CacheType
3-
fx::CacheType2
4-
fx1::CacheType3
5-
end
6-
7-
function GradientCache(x1,fx,_fx1,fdtype::DataType=Val{:central},
8-
RealOrComplex::DataType =
9-
fdtype==Val{:complex} ? Val{:Real} : eltype(x1) <: Complex ?
10-
Val{:Complex} : Val{:Real})
11-
if fdtype == Val{:complex} && _fx1 != nothing
12-
warn("fx1 cache is ignored when fdtype == Val{:complex}.")
13-
fx1 = nothing
1+
struct GradientCache{CacheType, CacheType2, CacheType3, fdtype, RealOrComplex}
2+
fx :: CacheType1
3+
c1 :: CacheType2
4+
c2 :: CacheType3
5+
end
6+
7+
function GradientCache(
8+
df :: AbstractArray{<:Number},
9+
x :: Union{<:Number, AbstractArray{<:Number}},
10+
fx :: Union{Void,AbstractArray{<:Number}} = nothing,
11+
c1 :: AbstractArray{<:Number} = nothing,
12+
c2 :: AbstractArray{<:Number} = nothing,
13+
fdtype :: DataType = Val{:central},
14+
RealOrComplex :: DataType =
15+
fdtype==Val{:complex} ? Val{:Real} : eltype(x1) <: Complex ? Val{:Complex} : Val{:Real}
16+
)
17+
18+
if fdtype == Val{:complex} && RealOrComplex == Val{:Complex}
19+
fdtype_error(Val{:Complex})
20+
end
21+
22+
if fdtype != Val{:forward}
23+
warn("Pre-computed function values are only useful for fdtype == Val{:forward}.")
24+
_fx = nothing
25+
else
26+
# more runtime sanity checks?
27+
_fx = fx
28+
end
29+
30+
if typeof(x) <: AbstractArray # the f:R^n->R case
31+
# need cache arrays for epsilon and x1
32+
epsilon_elemtype = compute_epsilon_elemtype(epsilon, x)
33+
if typeof(c1) == Void || eltype(c1) != epsilon_elemtype
34+
_c1 = zeros(epsilon_elemtype, size(x))
35+
else
36+
_c1 = c1
37+
end
38+
epsilon_factor = compute_epsilon_factor(fdtype, real(eltype(x)))
39+
@. _c1 = compute_epsilon(fdtype, real(x), epsilon_factor)
40+
41+
if typeof(c2) != typeof(x) || size(c2) != size(x)
42+
_c2 = copy(x)
43+
else
44+
copy!(_c2, x)
45+
end
46+
else # the f:R->R^n case
47+
# need cache arrays for fx1 and fx2
48+
if typeof(c1) != typeof(df) || size(c1) != size(df)
49+
_c1 = similar(df)
50+
else
51+
_c1 = c1
52+
end
53+
if typeof(c2) != typeof(df) || size(c2) != size(df)
54+
_c2 = similar(df)
55+
else
56+
_c2 = c2
57+
end
58+
end
59+
GradientCache{typeof(fx),typeof(c1),typeof(c2),fdtype,RealOrComplex}(fx,c1,c2)
60+
end
61+
62+
function finite_difference_gradient(f, x, fdtype::DataType=Val{:central},
63+
RealOrComplex::DataType =
64+
fdtype==Val{:complex} ? Val{:Real} : eltype(x) <: Complex ? Val{:Complex} : Val{:Real},
65+
fx::Union{Void,AbstractArray{<:Number}}=nothing,
66+
c1::Union{Void,AbstractArray{<:Number}}=nothing,
67+
c2::Union{Void,AbstractArray{<:Number}}=nothing,
68+
)
69+
70+
if typeof(x) <: AbstractArray
71+
df = similar(x)
1472
else
15-
fx1 = _fx1
73+
df = similar(f(x)) # can we get rid of this by requesting more information?
1674
end
17-
GradientCache{typeof(x1),typeof(fx),typeof(fx1),
18-
fdtype,RealOrComplex}(x1,fx,fx1)
75+
cache = GradientCache(df,x,fx,c1,c2,fdtype,RealOrComplex)
76+
finite_difference_gradient!(df,f,x,cache)
1977
end
2078

21-
function finite_difference_gradient(f,x,fdtype=Val{:central},
22-
RealOrComplex::DataType =
23-
fdtype==Val{:complex} ? Val{:Real} : eltype(x) <: Complex ?
24-
Val{:Complex} : Val{:Real})
25-
x1 = similar(x)
26-
fx = similar(x)
27-
fx1 = similar(x)
28-
cache = GradientCache(x1,fx,fx1,fdtype,RealOrComplex)
29-
finite_difference_gradient(f,x,cache)
79+
function finite_difference_gradient!(df, f, x, fdtype::DataType=Val{:central},
80+
RealOrComplex::DataType =
81+
fdtype==Val{:complex} ? Val{:Real} : eltype(x) <: Complex ? Val{:Complex} : Val{:Real},
82+
fx::Union{Void,AbstractArray{<:Number}}=nothing,
83+
c1::Union{Void,AbstractArray{<:Number}}=nothing,
84+
c2::Union{Void,AbstractArray{<:Number}}=nothing,
85+
)
86+
87+
cache = GradientCache(df,x,fx,c1,c2,fdtype,RealOrComplex)
88+
finite_difference_gradient!(df,f,x,cache)
3089
end
3190

3291
function finite_difference_gradient(f,x,cache::GradientCache)
33-
G = zeros(eltype(x), length(x), length(x))
34-
finite_difference_gradient!(J,f,x,cache)
35-
J
92+
if typeof(x) <: AbstractArray
93+
df = similar(x)
94+
else
95+
df = similar(cache.c1)
96+
end
97+
finite_difference_gradient!(df,f,x,cache)
98+
df
99+
end
100+
101+
# vector of derivatives of f : R^n -> R by each component of a vector x
102+
function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::AbstractArray{<:Number},
103+
cache::GradientCache{T1,T2,T3,fdtype,Val{:Real}}) where {T1,T2,T3,fdtype}
104+
105+
# NOTE: in this case epsilon is a vector, we need two arrays for epsilon and x1
106+
# c1 denotes epsilon (pre-computed by the cache constructor),
107+
# c2 is x1, pre-set to the values of x by the cache constructor
108+
fx, c1, c2 = cache.fx, cache.c1, cache.c2
109+
if fdtype == Val{:forward}
110+
# TODO: do we even need c2 for the forward case?
111+
@inbounds for i eachindex(x)
112+
c2[i] += c1[i]
113+
df[i] = (f(c2) - f(x)) / c1[i]
114+
c2[i] -= c1[i]
115+
end
116+
elseif fdtype == Val{:central}
117+
@inbounds for i eachindex(x)
118+
#copy!(x1,x)
119+
c2[i] += c1[i]
120+
x[i] -= c1[i]
121+
df[i] = (f(c2) - f(x)) / (2*c1[i])
122+
c2[i] -= c1[i]
123+
x[i] += c1[i] # revert any changes to x
124+
end
125+
elseif fdtype == Val{:complex}
126+
# TODO
127+
end
128+
df
129+
end
130+
131+
# vector of derivatives of f : R -> R^n
132+
# this is effectively a vector of partial derivatives, but we still call it a gradient
133+
function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Number,
134+
cache::GradientCache{T1,T2,T3,fdtype,Val{:Real}}) where {T1,T2,T3,fdtype}
135+
136+
# NOTE: in this case epsilon is a scalar, we need two arrays for fx1 and fx2
137+
# c1 denotes fx1, c2 is fx2, sizes guaranteed by the cache constructor
138+
fx, c1, c2 = cache.fx, cache.c1, cache.c2
139+
140+
if fdtype == Val{:forward}
141+
# TODO
142+
elseif fdtype == Val{:central}
143+
c1 .= f(x+epsilon)
144+
c2 .= f(x-epsilon)
145+
@inbounds for i 1 : length(fx)
146+
df[i] = (f(x+epsilon)[1] - f(x-epsilon)[1]) / (2*epsilon)
147+
end
148+
elseif fdtype == Val{:complex}
149+
# TODO
150+
end
151+
df
152+
end
153+
154+
155+
function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::AbstractArray{<:Number},
156+
cache::GradientCache{T1,T2,T3,fdtype,Val{:Complex}}) where {T1,T2,T3,fdtype}
157+
158+
# TODO
159+
df
160+
end
161+
162+
function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Number,
163+
cache::GradientCache{T1,T2,T3,fdtype,Val{:Complex}}) where {T1,T2,T3,fdtype}
164+
165+
# TODO
166+
df
36167
end

0 commit comments

Comments
 (0)