Skip to content

Fix type instabilities from extrapolation #407

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,28 @@ function _extrapolate_derivative_left(A, t, order)
elseif extrapolation_left == ExtrapolationType.Constant
zero(first(A.u) / one(A.t[1]))
elseif extrapolation_left == ExtrapolationType.Linear
(order == 1) ? derivative(A, first(A.t)) : zero(first(A.u) / one(A.t[1]))
_derivative(A, first(A.t), 1)
zero(first(A.u) / one(A.t[1]))
(order == 1) ? _derivative(A, first(A.t), 1) : zero(first(A.u) / one(A.t[1]))
elseif extrapolation_left == ExtrapolationType.Extension
iguess = A.iguesser
(order == 1) ? _derivative(A, t, iguess) :
(order == 1) ? _derivative(A, t, length(A.t)) :
ForwardDiff.derivative(t -> begin
_derivative(A, t, iguess)
_derivative(A, t, length(A.t))
end, t)
elseif extrapolation_left == ExtrapolationType.Periodic
t_, _ = transformation_periodic(A, t)
derivative(A, t_, order)
(order == 1) ? _derivative(A, t_, A.iguesser) :
ForwardDiff.derivative(t -> begin
_derivative(A, t, A.iguesser)
end, t_)
else
# extrapolation_left == ExtrapolationType.Reflective
t_, n = transformation_reflective(A, t)
isodd(n) ? -derivative(A, t_, order) : derivative(A, t_, order)
sign = isodd(n) ? -1 : 1
(order == 1) ? sign * _derivative(A, t_, A.iguesser) :
ForwardDiff.derivative(t -> begin
sign * _derivative(A, t, A.iguesser)
end, t_)
end
end

Expand All @@ -44,20 +52,27 @@ function _extrapolate_derivative_right(A, t, order)
elseif extrapolation_right == ExtrapolationType.Constant
zero(first(A.u) / one(A.t[1]))
elseif extrapolation_right == ExtrapolationType.Linear
(order == 1) ? derivative(A, last(A.t)) : zero(first(A.u) / one(A.t[1]))
(order == 1) ? _derivative(A, last(A.t), length(A.t)) :
zero(first(A.u) / one(A.t[1]))
elseif extrapolation_right == ExtrapolationType.Extension
iguess = A.iguesser
(order == 1) ? _derivative(A, t, iguess) :
(order == 1) ? _derivative(A, t, length(A.t)) :
ForwardDiff.derivative(t -> begin
_derivative(A, t, iguess)
_derivative(A, t, length(A.t))
end, t)
elseif extrapolation_right == ExtrapolationType.Periodic
t_, _ = transformation_periodic(A, t)
derivative(A, t_, order)
(order == 1) ? _derivative(A, t_, A.iguesser) :
ForwardDiff.derivative(t -> begin
_derivative(A, t, A.iguesser)
end, t_)
else
# extrapolation_right == ExtrapolationType.Reflective
# extrapolation_left == ExtrapolationType.Reflective
t_, n = transformation_reflective(A, t)
iseven(n) ? -derivative(A, t_, order) : derivative(A, t_, order)
sign = iseven(n) ? -1 : 1
(order == 1) ? sign * _derivative(A, t_, A.iguesser) :
ForwardDiff.derivative(t -> begin
sign * _derivative(A, t, A.iguesser)
end, t_)
end
end

Expand Down Expand Up @@ -299,4 +314,4 @@ function _derivative(
out += Δt₀^2 *
(3c₁ + (3Δt₁ + Δt₀) * c₂ + (3Δt₁^2 + Δt₀ * 2Δt₁) * c₃)
out
end
end
8 changes: 4 additions & 4 deletions src/integral_inverses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ Creates the inverted integral interpolation object from the given interpolation.

- `A`: interpolation object satisfying the above requirements
"""
invert_integral(A::AbstractInterpolation) = throw(IntegralInverseNotFoundError())
invert_integral(::AbstractInterpolation) = throw(IntegralInverseNotFoundError())

_integral(A::AbstractIntegralInverseInterpolation, idx, t) = throw(IntegralNotFoundError())
_integral(::AbstractIntegralInverseInterpolation, idx, t) = throw(IntegralNotFoundError())

function _derivative(A::AbstractIntegralInverseInterpolation, t::Number, iguess)
inv(A.itp(A(t)))
inv(A.itp(_interpolate(A, t, iguess)))
end

"""
Expand Down Expand Up @@ -119,4 +119,4 @@ function _interpolate(
idx_ = get_idx(A, t, idx; side = :first, lb = 1, ub_shift = 0)
end
A.u[idx] + (t - A.t[idx]) / A.itp.u[idx_]
end
end
10 changes: 5 additions & 5 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ function _extrapolate_left(A, t)
if extrapolation_left == ExtrapolationType.None
throw(LeftExtrapolationError())
elseif extrapolation_left == ExtrapolationType.Constant
slope = derivative(A, first(A.t))
slope = _derivative(A, first(A.t), 1)
first(A.u) + zero(slope * t)
elseif extrapolation_left == ExtrapolationType.Linear
slope = derivative(A, first(A.t))
slope = _derivative(A, first(A.t), 1)
first(A.u) + slope * (t - first(A.t))
else
_extrapolate_other(A, t, extrapolation_left)
Expand All @@ -28,10 +28,10 @@ function _extrapolate_right(A, t)
if extrapolation_right == ExtrapolationType.None
throw(RightExtrapolationError())
elseif extrapolation_right == ExtrapolationType.Constant
slope = derivative(A, last(A.t))
slope = _derivative(A, last(A.t), length(A.t))
last(A.u) + zero(slope * t)
elseif extrapolation_right == ExtrapolationType.Linear
slope = derivative(A, last(A.t))
slope = _derivative(A, last(A.t), length(A.t))
last(A.u) + slope * (t - last(A.t))
else
_extrapolate_other(A, t, extrapolation_right)
Expand Down Expand Up @@ -335,4 +335,4 @@ function _interpolate(
c₁, c₂, c₃ = get_parameters(A, idx)
out += Δt₀^3 * (c₁ + Δt₁ * (c₂ + c₃ * Δt₁))
out
end
end
2 changes: 2 additions & 0 deletions test/integral_inverse_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ function test_integral_inverses(method; args = [], kwargs = [])
adiff = derivative(A_intinv, I)
@test cdiff ≈ adiff
end

@test @inferred(A(ts[37])) == A(ts[37])
end

@testset "Linear Interpolation" begin
Expand Down
Loading