Skip to content

Commit 38d1817

Browse files
Merge pull request #433 from simulutions/configurable-extrapolation-method-in-invert_integral()
Add optional extrapolation methods
2 parents ccd89d7 + e4c3306 commit 38d1817

5 files changed

+61
-11
lines changed

src/integral_inverses.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ struct LinearInterpolationIntInv{uType, tType, itpType, T} <:
4141
extrapolation_right::ExtrapolationType.T
4242
iguesser::Guesser{tType}
4343
itp::itpType
44-
function LinearInterpolationIntInv(u, t, A)
44+
function LinearInterpolationIntInv(u, t, A, extrapolation_left, extrapolation_right)
4545
new{typeof(u), typeof(t), typeof(A), eltype(u)}(
46-
u, t, A.extrapolation_left, A.extrapolation_right, Guesser(t), A)
46+
u, t, extrapolation_left, extrapolation_right, Guesser(t), A)
4747
end
4848
end
4949

@@ -57,9 +57,14 @@ function get_I(A::AbstractInterpolation)
5757
I
5858
end
5959

60-
function invert_integral(A::LinearInterpolation{<:AbstractVector{<:Number}})
60+
function invert_integral(
61+
A::LinearInterpolation{<:AbstractVector{<:Number}};
62+
extrapolation_left::ExtrapolationType.T = A.extrapolation_left,
63+
extrapolation_right::ExtrapolationType.T = A.extrapolation_right)
6164
!invertible_integral(A) && throw(IntegralNotInvertibleError())
62-
return LinearInterpolationIntInv(A.t, get_I(A), A)
65+
66+
return LinearInterpolationIntInv(
67+
A.t, get_I(A), A, extrapolation_left, extrapolation_right)
6368
end
6469

6570
function _interpolate(
@@ -92,9 +97,10 @@ struct ConstantInterpolationIntInv{uType, tType, itpType, T} <:
9297
extrapolation_right::ExtrapolationType.T
9398
iguesser::Guesser{tType}
9499
itp::itpType
95-
function ConstantInterpolationIntInv(u, t, A)
100+
function ConstantInterpolationIntInv(
101+
u, t, A, extrapolation_left, extrapolation_right)
96102
new{typeof(u), typeof(t), typeof(A), eltype(u)}(
97-
u, t, A.extrapolation_left, A.extrapolation_right, Guesser(t), A
103+
u, t, extrapolation_left, extrapolation_right, Guesser(t), A
98104
)
99105
end
100106
end
@@ -103,9 +109,12 @@ function invertible_integral(A::ConstantInterpolation{<:AbstractVector{<:Number}
103109
return all(A.u .> 0)
104110
end
105111

106-
function invert_integral(A::ConstantInterpolation{<:AbstractVector{<:Number}})
112+
function invert_integral(A::ConstantInterpolation{<:AbstractVector{<:Number}};
113+
extrapolation_left::ExtrapolationType.T = A.extrapolation_left,
114+
extrapolation_right::ExtrapolationType.T = A.extrapolation_right)
107115
!invertible_integral(A) && throw(IntegralNotInvertibleError())
108-
return ConstantInterpolationIntInv(A.t, get_I(A), A)
116+
return ConstantInterpolationIntInv(
117+
A.t, get_I(A), A, extrapolation_left, extrapolation_right)
109118
end
110119

111120
function _interpolate(

src/interpolation_caches.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,8 +1220,8 @@ function BSplineApprox(
12201220
end
12211221
for k in 2:(n - 1)
12221222
q[ax_u...,
1223-
k] = u[ax_u..., k] - sc[k, 1] * u[ax_u..., 1] -
1224-
sc[k, h] * u[ax_u..., end]
1223+
k] = u[ax_u..., k] - sc[k, 1] * u[ax_u..., 1] -
1224+
sc[k, h] * u[ax_u..., end]
12251225
end
12261226
Q = Array{T, N}(undef, size(u)[1:(end - 1)]..., h - 2)
12271227
for i in 2:(h - 1)

src/interpolation_utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ function cumulative_integral(A::AbstractInterpolation{<:Number}, cache_parameter
191191
Base.require_one_based_indexing(A.u)
192192
idxs = cache_parameters ? (1:(length(A.t) - 1)) : (1:0)
193193
return cumsum(_integral(A, idx, t1, t2)
194-
for (idx, t1, t2) in zip(idxs, @view(A.t[begin:(end - 1)]), @view(A.t[(begin + 1):end])))
194+
for (idx, t1, t2) in
195+
zip(idxs, @view(A.t[begin:(end - 1)]), @view(A.t[(begin + 1):end])))
195196
end
196197

197198
function get_parameters(A::LinearInterpolation, idx)

test/integral_inverse_tests.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,40 @@ function test_integral_inverses(method; args = [], kwargs = [])
2121
@test @inferred(A(ts[37])) == A(ts[37])
2222
end
2323

24+
function test_integral_inverse_extrapolation()
25+
# Linear function with constant extrapolation
26+
t = collect(1:4)
27+
u = [1.0, 2.0, 3.0, 4.0]
28+
A = LinearInterpolation(u, t, extrapolation = ExtrapolationType.Constant)
29+
30+
A_intinv = invert_integral(A, extrapolation_left = ExtrapolationType.Extension,
31+
extrapolation_right = ExtrapolationType.Extension)
32+
33+
# for a linear function, the integral is quadratic
34+
# but the constant extrapolation part is linear.
35+
area_0_to_4 = 0.5 * 4.0^2
36+
area_4_to_5 = 4.0
37+
area = area_0_to_4 + area_4_to_5
38+
39+
@test A_intinv(area) 5.0
40+
end
41+
42+
function test_integral_inverse_const_extrapolation()
43+
# Constant function with constant extrapolation
44+
t = collect(1:4)
45+
u = [1.0, 1.0, 1.0, 1.0]
46+
A = ConstantInterpolation(u, t, extrapolation = ExtrapolationType.Extension)
47+
48+
A_intinv = invert_integral(A, extrapolation_left = ExtrapolationType.Extension,
49+
extrapolation_right = ExtrapolationType.Extension)
50+
51+
area_0_to_4 = 1.0 * (4.0 - 1.0)
52+
area_4_to_5 = 1.0
53+
area = area_0_to_4 + area_4_to_5
54+
55+
@test A_intinv(area) 5.0
56+
end
57+
2458
@testset "Linear Interpolation" begin
2559
t = collect(1:5)
2660
u = [1.0, 1.0, 2.0, 4.0, 3.0]
@@ -29,6 +63,8 @@ end
2963
u = [1.0, -1.0, 2.0, 4.0, 3.0]
3064
A = LinearInterpolation(u, t)
3165
@test_throws DataInterpolations.IntegralNotInvertibleError invert_integral(A)
66+
67+
test_integral_inverse_extrapolation()
3268
end
3369

3470
@testset "Constant Interpolation" begin
@@ -40,6 +76,8 @@ end
4076
u = [1.0, -1.0, 2.0, 4.0, 3.0]
4177
A = ConstantInterpolation(u, t)
4278
@test_throws DataInterpolations.IntegralNotInvertibleError invert_integral(A)
79+
80+
test_integral_inverse_const_extrapolation()
4381
end
4482

4583
t = collect(1:5)

test/interpolation_tests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,7 @@ end
10271027
ut1 = Float32[0.1, 0.2, 0.3, 0.4, 0.5]
10281028
ut2 = Float64[0.1, 0.2, 0.3, 0.4, 0.5]
10291029
for u in (ut1, ut2), t in (ut1, ut2)
1030+
10301031
interp = @inferred(LinearInterpolation(ut1, ut2))
10311032
for xs in (u, t)
10321033
ys = @inferred(interp(xs))
@@ -1109,6 +1110,7 @@ f_cubic_spline = c -> square(CubicSpline, c)
11091110
iszero_allocations(u, t) = iszero(@allocated(DataInterpolations.munge_data(u, t)))
11101111

11111112
for T in (String, Union{String, Missing}), dims in 1:3
1113+
11121114
_u0 = convert(Array{T}, reshape(u0, ntuple(i -> i == dims ? 3 : 1, dims)))
11131115

11141116
u, t = @inferred(DataInterpolations.munge_data(_u0, t0))

0 commit comments

Comments
 (0)