Skip to content

Commit e16f147

Browse files
committed
Some restructuring and DiffEq wrapper support
1 parent d5817b5 commit e16f147

File tree

2 files changed

+56
-28
lines changed

2 files changed

+56
-28
lines changed

src/finitediff.jl

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ end
1414
eps_cbrt * max(one(T), abs(x))
1515
end
1616

17-
@inline function compute_epsilon{T<:Complex}(::Type{Val{:Complex}}, x::T)
18-
eps(real(x))
17+
@inline function compute_epsilon{T<:Real}(::Type{Val{:complex}}, x::T, ::Union{Void,T}=nothing)
18+
eps(x)
1919
end
2020

2121
@inline function compute_epsilon_factor{T<:Real}(fdtype::DataType, ::Type{T})
@@ -34,28 +34,63 @@ Compute the derivative df of a real-valued callable f on a collection of points
3434
Generic fallbacks for AbstractArrays that are not StridedArrays.
3535
# TODO: test the fallbacks
3636
=#
37-
function finite_difference{T<:Real}(f, x::AbstractArray{T}, fdtype::DataType, fx::Union{Void,AbstractArray{T}}=nothing)
37+
function finite_difference{T<:Real}(f, x::AbstractArray{T}, fdtype::DataType, fx::Union{Void,AbstractArray{T}}=nothing, funtype::DataType=Val{:Default})
3838
df = zeros(T, size(x))
39-
finite_difference!(df, f, x, fdtype, fx)
39+
finite_difference!(df, f, x, fdtype, fx, funtype)
4040
end
4141

42-
function finite_difference!{T<:Real}(df::AbstractArray{T}, f, x::AbstractArray{T}, ::Type{Val{:forward}}, f_x::AbstractArray{T}=f.(x))
43-
epsilon_factor = compute_epsilon_factor(Val{:forward}, T)
44-
@. epsilon = compute_epsilon(Val{:forward}, x)
45-
@. df = (f(x+epsilon) - f_x) / epsilon
42+
function finite_difference!{T<:Real}(df::AbstractArray{T}, f, x::AbstractArray{T}, fdtype::DataType, fx::Union{Void,AbstractArray{T}}, ::Type{Val{:Default}})
43+
epsilon_factor = compute_epsilon_factor(fdtype, T)
44+
@. epsilon = compute_epsilon(fdtype, x, epsilon_factor)
45+
if fdtype == Val{:forward}
46+
if typeof(fx) == Void
47+
@. df = (f(x+epsilon) - f(x)) / epsilon
48+
else
49+
@. df = (f(x+epsilon) - fx) / epsilon
50+
end
51+
elseif fdtype == Val{:central}
52+
@. df = (f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
53+
end
54+
df
4655
end
4756

48-
function finite_difference!{T<:Real}(df::AbstractArray{T}, f, x::AbstractArray{T}, ::Type{Val{:central}}, ::Union{Void,AbstractArray{T}}=nothing)
49-
epsilon_factor = compute_epsilon_factor(Val{:central}, T)
50-
@. epsilon = compute_epsilon(Val{:central}, x, epsilon_factor)
51-
@. df = (f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
57+
function finite_difference!{T<:Real}(df::AbstractArray{T}, f, x::AbstractArray{T}, fdtype::DataType, fx::Union{Void,AbstractArray{T}}, ::Type{Val{:DiffEqDerivativeWrapper}})
58+
epsilon_factor = compute_epsilon_factor(fdtype, T)
59+
@. epsilon = compute_epsilon(fdtype, x, epsilon_factor)
60+
error("Not implemented yet.")
61+
62+
if fdtype == Val{:forward}
63+
if typeof(fx) == Void
64+
65+
else
66+
67+
end
68+
elseif fdtype == Val{:central}
69+
70+
end
71+
df
72+
end
73+
74+
function finite_difference!{T<:Real}(df::AbstractArray{T}, f, x::T, fdtype::DataType, fx::AbstractArray{T}, ::Type{Val{:DiffEqDerivativeWrapper}})
75+
epsilon = compute_epsilon(fdtype, x)
76+
fx1 = f.fx1
77+
if fdtype == Val{:forward}
78+
f(fx, x)
79+
f(fx1, x+epsilon)
80+
@. df = (fx1 - fx) / epsilon
81+
elseif fdtype == Val{:central}
82+
f(fx, x-epsilon)
83+
f(fx1, x+epsilon)
84+
@. df = (fx1 - fx) / (2 * epsilon)
85+
end
86+
df
5287
end
5388

5489
#=
5590
Compute the derivative df of a real-valued callable f on a collection of points x.
5691
Optimized implementations for StridedArrays.
5792
=#
58-
function finite_difference!{T<:Real}(df::StridedArray{T}, f, x::StridedArray{T}, ::Type{Val{:central}}, ::Union{Void,StridedArray{T}}=nothing)
93+
function finite_difference!{T<:Real}(df::StridedArray{T}, f, x::StridedArray{T}, ::Type{Val{:central}}, ::Union{Void,StridedArray{T}}, ::Type{Val{:Default}})
5994
epsilon_factor = compute_epsilon_factor(Val{:central}, T)
6095
@inbounds for i in 1 : length(x)
6196
epsilon = compute_epsilon(Val{:central}, x[i], epsilon_factor)
@@ -66,17 +101,7 @@ function finite_difference!{T<:Real}(df::StridedArray{T}, f, x::StridedArray{T},
66101
df
67102
end
68103

69-
function finite_difference{T<:Real}(f, x::StridedArray{T}, ::Type{Val{:forward}}, fx::Union{Void,StridedArray{T}})
70-
df = zeros(T, size(x))
71-
if typeof(fx) == Void
72-
finite_difference!(df, f, x, Val{:forward})
73-
else
74-
finite_difference!(df, f, x, Val{:forward}, fx)
75-
end
76-
df
77-
end
78-
79-
function finite_difference!{T<:Real}(df::StridedArray{T}, f, x::StridedArray{T}, ::Type{Val{:forward}})
104+
function finite_difference!{T<:Real}(df::StridedArray{T}, f, x::StridedArray{T}, ::Type{Val{:forward}}, ::Void, ::Type{Val{:Default}})
80105
epsilon_factor = compute_epsilon_factor(Val{:forward}, T)
81106
@inbounds for i in 1 : length(x)
82107
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)
@@ -87,7 +112,7 @@ function finite_difference!{T<:Real}(df::StridedArray{T}, f, x::StridedArray{T},
87112
df
88113
end
89114

90-
function finite_difference!{T<:Real}(df::StridedArray{T}, f, x::StridedArray{T}, ::Type{Val{:forward}}, fx::StridedArray{T})
115+
function finite_difference!{T<:Real}(df::StridedArray{T}, f, x::StridedArray{T}, ::Type{Val{:forward}}, fx::StridedArray{T}, ::Type{Val{:Default}})
91116
epsilon_factor = compute_epsilon_factor(Val{:forward}, T)
92117
@inbounds for i in 1 : length(x)
93118
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)

test/finitedifftests.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@ df_ref = cos.(x)
66
# TODO: add tests for non-StridedArrays and with more complicated functions
77

88
# derivative tests
9-
@test maximum(abs.(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:forward}) - df_ref)) < 1e-4
10-
@test maximum(abs.(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:forward}, y) - df_ref)) < 1e-4
11-
@test maximum(abs.(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:central}) - df_ref)) < 1e-8
9+
@test maximum(abs.(DiffEqDiffTools.finite_difference(sin, x, Val{:forward}) - df_ref)) < 1e-4
10+
@test maximum(abs.(DiffEqDiffTools.finite_difference(sin, x, Val{:forward}, y) - df_ref)) < 1e-4
11+
@test maximum(abs.(DiffEqDiffTools.finite_difference(sin, x, Val{:central}) - df_ref)) < 1e-8
12+
@test maximum(abs.(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:forward}, nothing, Val{:Default}) - df_ref)) < 1e-4
13+
@test maximum(abs.(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:forward}, y, Val{:Default}) - df_ref)) < 1e-4
14+
@test maximum(abs.(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:central}, nothing, Val{:Default}) - df_ref)) < 1e-8
1215

1316
# Jacobian tests
1417
using Calculus

0 commit comments

Comments
 (0)