Skip to content

Commit af6c8e3

Browse files
committed
Basic finite difference functionality - performant derivatives of real-valued callables
1 parent 2781441 commit af6c8e3

File tree

4 files changed

+138
-3
lines changed

4 files changed

+138
-3
lines changed

src/DiffEqDiffTools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ __precompile__()
22

33
module DiffEqDiffTools
44

5-
# package code goes here
5+
include("finitediff.jl")
66

77
end # module

src/finitediff.jl

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#=
2+
Very heavily inspired by Calculus.jl, but with an emphasis on performance and DiffEq API convenience.
3+
=#
4+
5+
#=
6+
Compute the finite difference interval epsilon.
7+
Reference: Numerical Recipes, chapter 5.7.
8+
=#
9+
@inline function compute_epsilon{T<:Real}(::Type{Val{:forward}}, x::T, eps_sqrt::T=sqrt(eps(T)))
10+
eps_sqrt * max(one(T), abs(x))
11+
end
12+
13+
@inline function compute_epsilon{T<:Real}(::Type{Val{:central}}, x::T, eps_cbrt::T=cbrt(eps(T)))
14+
eps_cbrt * max(one(T), abs(x))
15+
end
16+
17+
@inline function compute_epsilon{T<:Complex}(::Type{Val{:complex}}, x::T)
18+
eps(real(x))
19+
end
20+
21+
22+
#=
23+
Compute the derivative df of a real-valued callable f on a collection of points x.
24+
Generic fallbacks for AbstractArrays that are not StridedArrays.
25+
=#
26+
function finite_difference{T<:Real}(f, x::AbstractArray{T}, ::Type{Val{:central}}, ::Union{Void,AbstractArray{T}}=nothing)
27+
df = zeros(T, size(x))
28+
finite_difference!(df, f, x, Val{:central})
29+
end
30+
31+
function finite_difference!{T<:Real}(df::AbstractArray{T}, f, x::AbstractArray{T}, ::Type{Val{:central}}, ::Union{Void,AbstractArray{T}}=nothing)
32+
eps_sqrt = sqrt(eps(T))
33+
epsilon = compute_epsilon.(Val{:central}, x, eps_sqrt)
34+
@. df = (f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
35+
end
36+
37+
function finite_difference{T<:Real}(f, x::AbstractArray{T}, ::Type{Val{:forward}}, f_x::AbstractArray{T}=f.(x))
38+
df = zeros(T, size(x))
39+
finite_difference!(df, f, x, Val{:forward}, f_x)
40+
end
41+
42+
function finite_difference!{T<:Real}(df::AbstractArray{T}, f, x::AbstractArray{T}, ::Type{Val{:forward}}, f_x::AbstractArray{T}=f.(x))
43+
eps_cbrt = cbrt(eps(T))
44+
epsilon = compute_epsilon.(Val{:forward}, x, eps_cbrt)
45+
@. df = (f(x+epsilon) - f_x) / epsilon
46+
end
47+
48+
49+
#=
50+
Compute the derivative df of a real-valued callable f on a collection of points x.
51+
Optimized implementations for StridedArrays.
52+
=#
53+
function finite_difference{T<:Real}(f, x::StridedArray{T}, ::Type{Val{:central}}, ::Union{Void,StridedArray{T}}=nothing)
54+
df = zeros(T, size(x))
55+
finite_difference!(df, f, x, Val{:central})
56+
end
57+
58+
function finite_difference!{T<:Real}(df::StridedArray{T}, f, x::StridedArray{T}, ::Type{Val{:central}}, ::Union{Void,StridedArray{T}}=nothing)
59+
eps_sqrt = sqrt(eps(T))
60+
@inbounds for i in 1 : length(x)
61+
epsilon = compute_epsilon(Val{:central}, x[i], eps_sqrt)
62+
epsilon_double_inv = one(T) / (2*epsilon)
63+
x_plus, x_minus = x[i]+epsilon, x[i]-epsilon
64+
df[i] = (f(x_plus) - f(x_minus)) * epsilon_double_inv
65+
end
66+
df
67+
end
68+
69+
function finite_difference{T<:Real}(f, x::StridedArray{T}, ::Type{Val{:forward}}, fx::Union{Void,StridedArray{T}})
70+
df = zeros(T, size(x))
71+
if typeof(fx) == Void
72+
finite_difference!(df, f, x, Val{:forward})
73+
else
74+
finite_difference!(df, f, x, Val{:forward}, fx)
75+
end
76+
df
77+
end
78+
79+
function finite_difference!{T<:Real}(df::StridedArray{T}, f, x::StridedArray{T}, ::Type{Val{:forward}})
80+
eps_cbrt = cbrt(eps(T))
81+
@fastmath @inbounds for i in 1 : length(x)
82+
epsilon = compute_epsilon(Val{:forward}, x[i], eps_cbrt)
83+
epsilon_inv = one(T) / epsilon
84+
x_plus = x[i] + epsilon
85+
df[i] = (f(x_plus) - f(x[i])) * epsilon_inv
86+
end
87+
df
88+
end
89+
90+
function finite_difference!{T<:Real}(df::StridedArray{T}, f, x::StridedArray{T}, ::Type{Val{:forward}}, fx::StridedArray{T})
91+
eps_cbrt = cbrt(eps(T))
92+
@fastmath @inbounds for i in 1 : length(x)
93+
epsilon = compute_epsilon(Val{:forward}, x[i], eps_cbrt)
94+
epsilon_inv = one(T) / epsilon
95+
x_plus = x[i] + epsilon
96+
df[i] = (f(x_plus) - fx[i]) * epsilon_inv
97+
end
98+
df
99+
end
100+
101+
#=
102+
Compute the derivative df of a real-valued callable f on a collection of points x.
103+
Single point implementations.
104+
=#
105+
function finite_difference{T<:Real}(f, x::T, t::DataType, f_x::Union{Void,T}=nothing)
106+
epsilon = compute_epsilon(t, x)
107+
finite_difference_kernel(f, x, t, epsilon, f_x)
108+
end
109+
110+
@inline function finite_difference_kernel{T<:Real}(f, x::T, ::Type{Val{:forward}}, epsilon::T, f_x::Union{Void,T})
111+
if typeof(f_x) == Void
112+
return (f(x+epsilon) - f(x)) / epsilon
113+
else
114+
return (f(x+epsilon) - f_x) / epsilon
115+
end
116+
end
117+
118+
@inline function finite_difference_kernel{T<:Real}(f, x::T, ::Type{Val{:central}}, epsilon::T, ::Union{Void,T}=nothing)
119+
(f(x+epsilon) - f(x-epsilon)) / (2 * epsilon)
120+
end
121+
122+
# TODO: derivatives for complex-valued callables
123+
124+
125+
#=
126+
Compute the Jacobian matrix of a real-valued callable f.
127+
=#
128+
# TODO

test/finitedifftests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
x = collect(linspace(-2π, 2π, 100))
2+
y = sin.(x)
3+
df = zeros(100)
4+
df_ref = cos.(x)
5+
6+
@test maximum(abs.(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:central}) - df_ref)) < 1e-8
7+
@test maximum(abs.(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:forward}) - df_ref)) < 1e-4
8+
@test maximum(abs.(DiffEqDiffTools.finite_difference!(df, sin, x, Val{:forward}, y) - df_ref)) < 1e-4

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using DiffEqDiffTools
22
using Base.Test
33

4-
# write your own tests here
5-
@test 1 == 2
4+
include("finitedifftests.jl")

0 commit comments

Comments
 (0)