Skip to content

Commit 2527aef

Browse files
Merge pull request #407 from SouthEndMusic/fix_type_instability
Fix type instabilities from extrapolation
2 parents 7cfc72d + 8f70cdf commit 2527aef

File tree

4 files changed

+40
-23
lines changed

4 files changed

+40
-23
lines changed

src/derivatives.jl

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,28 @@ function _extrapolate_derivative_left(A, t, order)
2020
elseif extrapolation_left == ExtrapolationType.Constant
2121
zero(first(A.u) / one(A.t[1]))
2222
elseif extrapolation_left == ExtrapolationType.Linear
23-
(order == 1) ? derivative(A, first(A.t)) : zero(first(A.u) / one(A.t[1]))
23+
_derivative(A, first(A.t), 1)
24+
zero(first(A.u) / one(A.t[1]))
25+
(order == 1) ? _derivative(A, first(A.t), 1) : zero(first(A.u) / one(A.t[1]))
2426
elseif extrapolation_left == ExtrapolationType.Extension
25-
iguess = A.iguesser
26-
(order == 1) ? _derivative(A, t, iguess) :
27+
(order == 1) ? _derivative(A, t, length(A.t)) :
2728
ForwardDiff.derivative(t -> begin
28-
_derivative(A, t, iguess)
29+
_derivative(A, t, length(A.t))
2930
end, t)
3031
elseif extrapolation_left == ExtrapolationType.Periodic
3132
t_, _ = transformation_periodic(A, t)
32-
derivative(A, t_, order)
33+
(order == 1) ? _derivative(A, t_, A.iguesser) :
34+
ForwardDiff.derivative(t -> begin
35+
_derivative(A, t, A.iguesser)
36+
end, t_)
3337
else
3438
# extrapolation_left == ExtrapolationType.Reflective
3539
t_, n = transformation_reflective(A, t)
36-
isodd(n) ? -derivative(A, t_, order) : derivative(A, t_, order)
40+
sign = isodd(n) ? -1 : 1
41+
(order == 1) ? sign * _derivative(A, t_, A.iguesser) :
42+
ForwardDiff.derivative(t -> begin
43+
sign * _derivative(A, t, A.iguesser)
44+
end, t_)
3745
end
3846
end
3947

@@ -44,20 +52,27 @@ function _extrapolate_derivative_right(A, t, order)
4452
elseif extrapolation_right == ExtrapolationType.Constant
4553
zero(first(A.u) / one(A.t[1]))
4654
elseif extrapolation_right == ExtrapolationType.Linear
47-
(order == 1) ? derivative(A, last(A.t)) : zero(first(A.u) / one(A.t[1]))
55+
(order == 1) ? _derivative(A, last(A.t), length(A.t)) :
56+
zero(first(A.u) / one(A.t[1]))
4857
elseif extrapolation_right == ExtrapolationType.Extension
49-
iguess = A.iguesser
50-
(order == 1) ? _derivative(A, t, iguess) :
58+
(order == 1) ? _derivative(A, t, length(A.t)) :
5159
ForwardDiff.derivative(t -> begin
52-
_derivative(A, t, iguess)
60+
_derivative(A, t, length(A.t))
5361
end, t)
5462
elseif extrapolation_right == ExtrapolationType.Periodic
5563
t_, _ = transformation_periodic(A, t)
56-
derivative(A, t_, order)
64+
(order == 1) ? _derivative(A, t_, A.iguesser) :
65+
ForwardDiff.derivative(t -> begin
66+
_derivative(A, t, A.iguesser)
67+
end, t_)
5768
else
58-
# extrapolation_right == ExtrapolationType.Reflective
69+
# extrapolation_left == ExtrapolationType.Reflective
5970
t_, n = transformation_reflective(A, t)
60-
iseven(n) ? -derivative(A, t_, order) : derivative(A, t_, order)
71+
sign = iseven(n) ? -1 : 1
72+
(order == 1) ? sign * _derivative(A, t_, A.iguesser) :
73+
ForwardDiff.derivative(t -> begin
74+
sign * _derivative(A, t, A.iguesser)
75+
end, t_)
6176
end
6277
end
6378

@@ -299,4 +314,4 @@ function _derivative(
299314
out += Δt₀^2 *
300315
(3c₁ + (3Δt₁ + Δt₀) * c₂ + (3Δt₁^2 + Δt₀ * 2Δt₁) * c₃)
301316
out
302-
end
317+
end

src/integral_inverses.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ Creates the inverted integral interpolation object from the given interpolation.
1313
1414
- `A`: interpolation object satisfying the above requirements
1515
"""
16-
invert_integral(A::AbstractInterpolation) = throw(IntegralInverseNotFoundError())
16+
invert_integral(::AbstractInterpolation) = throw(IntegralInverseNotFoundError())
1717

18-
_integral(A::AbstractIntegralInverseInterpolation, idx, t) = throw(IntegralNotFoundError())
18+
_integral(::AbstractIntegralInverseInterpolation, idx, t) = throw(IntegralNotFoundError())
1919

2020
function _derivative(A::AbstractIntegralInverseInterpolation, t::Number, iguess)
21-
inv(A.itp(A(t)))
21+
inv(A.itp(_interpolate(A, t, iguess)))
2222
end
2323

2424
"""
@@ -119,4 +119,4 @@ function _interpolate(
119119
idx_ = get_idx(A, t, idx; side = :first, lb = 1, ub_shift = 0)
120120
end
121121
A.u[idx] + (t - A.t[idx]) / A.itp.u[idx_]
122-
end
122+
end

src/interpolation_methods.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ function _extrapolate_left(A, t)
1313
if extrapolation_left == ExtrapolationType.None
1414
throw(LeftExtrapolationError())
1515
elseif extrapolation_left == ExtrapolationType.Constant
16-
slope = derivative(A, first(A.t))
16+
slope = _derivative(A, first(A.t), 1)
1717
first(A.u) + zero(slope * t)
1818
elseif extrapolation_left == ExtrapolationType.Linear
19-
slope = derivative(A, first(A.t))
19+
slope = _derivative(A, first(A.t), 1)
2020
first(A.u) + slope * (t - first(A.t))
2121
else
2222
_extrapolate_other(A, t, extrapolation_left)
@@ -28,10 +28,10 @@ function _extrapolate_right(A, t)
2828
if extrapolation_right == ExtrapolationType.None
2929
throw(RightExtrapolationError())
3030
elseif extrapolation_right == ExtrapolationType.Constant
31-
slope = derivative(A, last(A.t))
31+
slope = _derivative(A, last(A.t), length(A.t))
3232
last(A.u) + zero(slope * t)
3333
elseif extrapolation_right == ExtrapolationType.Linear
34-
slope = derivative(A, last(A.t))
34+
slope = _derivative(A, last(A.t), length(A.t))
3535
last(A.u) + slope * (t - last(A.t))
3636
else
3737
_extrapolate_other(A, t, extrapolation_right)
@@ -335,4 +335,4 @@ function _interpolate(
335335
c₁, c₂, c₃ = get_parameters(A, idx)
336336
out += Δt₀^3 * (c₁ + Δt₁ * (c₂ + c₃ * Δt₁))
337337
out
338-
end
338+
end

test/integral_inverse_tests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ function test_integral_inverses(method; args = [], kwargs = [])
1717
adiff = derivative(A_intinv, I)
1818
@test cdiff adiff
1919
end
20+
21+
@test @inferred(A(ts[37])) == A(ts[37])
2022
end
2123

2224
@testset "Linear Interpolation" begin

0 commit comments

Comments
 (0)