Skip to content

Commit 86f03db

Browse files
restructure to separate allocating from non-allocating cache construct
1 parent 9467a91 commit 86f03db

File tree

2 files changed

+69
-46
lines changed

2 files changed

+69
-46
lines changed

src/gradients.jl

Lines changed: 53 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,51 @@ struct GradientCache{CacheType1, CacheType2, CacheType3, fdtype, returntype, inp
55
end
66

77
function GradientCache(
8-
df :: AbstractArray{<:Number},
8+
df :: Union{<:Number,AbstractArray{<:Number}},
99
x :: Union{<:Number, AbstractArray{<:Number}},
10+
fdtype :: Type{T1} = Val{:central},
11+
returntype :: Type{T2} = eltype(df),
12+
inplace :: Type{Val{T3}} = Val{true}) where {T1,T2,T3}
13+
14+
if typeof(x)<:AbstractArray # the vector->scalar case
15+
if fdtype!=Val{:complex} # complex-mode FD only needs one cache, for x+eps*im
16+
if typeof(x)<:StridedVector
17+
if eltype(df)<:Complex && !(eltype(x)<:Complex)
18+
_c1 = zeros(Complex{eltype(x)}, size(x))
19+
_c2 = nothing
20+
else
21+
_c1 = nothing
22+
_c2 = nothing
23+
end
24+
else
25+
_c1 = similar(x)
26+
_c2 = zeros(real(eltype(x)), size(x))
27+
end
28+
else
29+
if !(returntype<:Real)
30+
fdtype_error(returntype)
31+
else
32+
_c1 = x + 0*im
33+
_c2 = nothing
34+
end
35+
end
36+
else # the scalar->vector case
37+
# need cache arrays for fx1 and fx2, except in complex mode, which needs one complex array
38+
if fdtype != Val{:complex}
39+
_c1 = similar(df)
40+
_c2 = similar(df)
41+
else
42+
_c1 = zeros(Complex{eltype(x)}, size(df))
43+
_c2 = nothing
44+
end
45+
end
46+
47+
GradientCache{Void,typeof(_c1),typeof(_c2),fdtype,
48+
returntype,inplace}(nothing,_c1,_c2)
49+
50+
end
51+
52+
function GradientCache(
1053
fx :: Union{Void,<:Number,AbstractArray{<:Number}} = nothing,
1154
c1 :: Union{Void,AbstractArray{<:Number}} = nothing,
1255
c2 :: Union{Void,AbstractArray{<:Number}} = nothing,
@@ -37,16 +80,8 @@ function GradientCache(
3780
end
3881
end
3982
else
40-
if typeof(c1)!=typeof(x) || size(c1)!=size(x)
41-
_c1 = similar(x)
42-
else
43-
_c1 = c1
44-
end
45-
if (typeof(c2)==Void || eltype(c2)!=real(eltype(x)))
46-
_c2 = zeros(real(eltype(x)), size(x))
47-
else
48-
_c2 = c2
49-
end
83+
_c1 = c1
84+
_c2 = c2
5085
end
5186
else
5287
if !(returntype<:Real)
@@ -60,26 +95,10 @@ function GradientCache(
6095
else # the scalar->vector case
6196
# need cache arrays for fx1 and fx2, except in complex mode, which needs one complex array
6297
if fdtype != Val{:complex}
63-
if typeof(c1)==Void || size(c1) != size(df)
64-
_c1 = similar(df)
65-
else
66-
_c1 = c1
67-
end
68-
if fdtype == Val{:forward} && typeof(fx) != Void
69-
_c2 = nothing
70-
else
71-
if typeof(c2) != typeof(df) || size(c2) != size(df)
72-
_c2 = similar(df)
73-
else
74-
_c2 = c2
75-
end
76-
end
98+
_c1 = c1
99+
_c2 = c2
77100
else
78-
if typeof(c1)==Void || size(c1)!=size(df)
79-
_c1 = zeros(Complex{eltype(x)}, size(df))
80-
else
81-
_c1 = c1
82-
end
101+
_c1 = c1
83102
_c2 = nothing
84103
end
85104
end
@@ -108,7 +127,7 @@ function finite_difference_gradient(f, x, fdtype::Type{T1}=Val{:central},
108127
df = similar(f(x))
109128
end
110129
end
111-
cache = GradientCache(df,x,fx,c1,c2,fdtype,returntype,inplace)
130+
cache = GradientCache(df,x,fdtype,returntype,inplace)
112131
finite_difference_gradient!(df,f,x,cache)
113132
end
114133

@@ -119,7 +138,7 @@ function finite_difference_gradient!(df, f, x, fdtype::Type{T1}=Val{:central},
119138
c2::Union{Void,AbstractArray{<:Number}}=nothing,
120139
) where {T1,T2,T3}
121140

122-
cache = GradientCache(df,x,fx,c1,c2,fdtype,returntype,inplace)
141+
cache = GradientCache(df,x,fdtype,returntype,inplace)
123142
finite_difference_gradient!(df,f,x,cache)
124143
end
125144

@@ -230,8 +249,10 @@ function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedV
230249
fx0 = f(x)
231250
x[i] += epsilon
232251
dfi = (f(x) - fx0) / epsilon
252+
@show dfi
233253
x[i] = x_old
234254
end
255+
235256
df[i] = real(dfi)
236257
if eltype(df)<:Complex
237258
if eltype(x)<:Complex

test/finitedifftests.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using DiffEqDiffTools, Base.Test
12

23
# TODO: add tests for GPUArrays
34
# TODO: add tests for DEDataArrays
@@ -162,9 +163,9 @@ x = rand(2)
162163
fx = f(x)
163164
df = zeros(2)
164165
df_ref = [2., 2*x[2]]
165-
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward})
166-
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central})
167-
complex_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:complex})
166+
forward_cache = DiffEqDiffTools.GradientCache(df,x,Val{:forward})
167+
central_cache = DiffEqDiffTools.GradientCache(df,x,Val{:central})
168+
complex_cache = DiffEqDiffTools.GradientCache(df,x,Val{:complex})
168169

169170
@time @testset "Gradient of f:vector->scalar real-valued tests" begin
170171
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}), df_ref) < 1e-4
@@ -185,8 +186,8 @@ x = x + im*x
185186
fx = f(x)
186187
df = zeros(x)
187188
df_ref = conj([2.0+2.0*im, 2.0*x[2]])
188-
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward})
189-
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central})
189+
forward_cache = DiffEqDiffTools.GradientCache(df,x,Val{:forward})
190+
central_cache = DiffEqDiffTools.GradientCache(df,x,Val{:central})
190191

191192
@time @testset "Gradient of f : C^N -> C tests" begin
192193
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}), df_ref) < 1e-4
@@ -204,8 +205,8 @@ x = ones(2) * (1 + im)
204205
fx = f(x)
205206
df = zeros(x)
206207
df_ref = 2*x
207-
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward})
208-
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central})
208+
forward_cache = DiffEqDiffTools.GradientCache(df,x,Val{:forward})
209+
central_cache = DiffEqDiffTools.GradientCache(df,x,Val{:central})
209210

210211
@time @testset "Gradient of f : C^N -> R tests" begin
211212
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}), df_ref) < 1e-4
@@ -223,8 +224,8 @@ x = ones(2)
223224
fx = f(x)
224225
df = zeros(eltype(fx), size(x))
225226
df_ref = [2.0, -im*2*x[2]]
226-
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward},eltype(df))
227-
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central},eltype(df))
227+
forward_cache = DiffEqDiffTools.GradientCache(df,x,Val{:forward},eltype(df))
228+
central_cache = DiffEqDiffTools.GradientCache(df,x,Val{:central},eltype(df))
228229

229230
@time @testset "Gradient of f : R^N -> C tests" begin
230231
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}, eltype(df)), df_ref) < 1e-4
@@ -243,9 +244,10 @@ fx = zeros(2)
243244
f(fx,x)
244245
df = zeros(2)
245246
df_ref = [cos(x), -sin(x)]
246-
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward})
247-
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central})
248-
complex_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:complex})
247+
forward_cache = DiffEqDiffTools.GradientCache(df,x,Val{:forward})
248+
central_cache = DiffEqDiffTools.GradientCache(df,x,Val{:central})
249+
complex_cache = DiffEqDiffTools.GradientCache(df,x,Val{:complex})
250+
249251

250252
@time @testset "Gradient of f:scalar->vector real-valued tests" begin
251253
@test_broken err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}), df_ref) < 1e-4
@@ -268,8 +270,8 @@ fx = zeros(typeof(x), 2)
268270
f(fx,x)
269271
df = zeros(fx)
270272
df_ref = [cos(x), -sin(x)]
271-
forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forward})
272-
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central})
273+
forward_cache = DiffEqDiffTools.GradientCache(df,x,Val{:forward})
274+
central_cache = DiffEqDiffTools.GradientCache(df,x,Val{:central})
273275

274276
@time @testset "Gradient of f:vector->scalar complex-valued tests" begin
275277
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}, eltype(x), Val{true}, fx), df_ref) < 1e-4

0 commit comments

Comments
 (0)