Skip to content

Commit 5391a3c

Browse files
Merge pull request #393 from SciML/dw/cumulative_integral
Improve `cumulative_integral`
2 parents b983363 + 66f8b17 commit 5391a3c

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

src/interpolation_utils.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,14 +189,12 @@ function get_idx(A::AbstractInterpolation, t, iguess::Union{<:Integer, Guesser};
189189
end
190190
end
191191

192-
function cumulative_integral(A, cache_parameters)
193-
if cache_parameters && hasmethod(_integral, Tuple{typeof(A), Number, Number, Number})
194-
integral_values = _integral.(
195-
Ref(A), 1:(length(A.t) - 1), A.t[1:(end - 1)], A.t[2:end])
196-
cumsum(integral_values)
197-
else
198-
promote_type(eltype(A.u), eltype(A.t))[]
199-
end
192+
cumulative_integral(::AbstractInterpolation, ::Bool) = nothing
193+
function cumulative_integral(A::AbstractInterpolation{<:Number}, cache_parameters::Bool)
194+
Base.require_one_based_indexing(A.u)
195+
idxs = cache_parameters ? (1:(length(A.t) - 1)) : (1:0)
196+
return cumsum(_integral(A, idx, t1, t2)
197+
for (idx, t1, t2) in zip(idxs, @view(A.t[begin:(end - 1)]), @view(A.t[(begin + 1):end])))
200198
end
201199

202200
function get_parameters(A::LinearInterpolation, idx)

test/integral_tests.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using DataInterpolations: integral
44
using Optim, ForwardDiff
55
using RegularizationTools
66
using StableRNGs
7+
using Unitful
78

89
function test_integral(method; args = [], kwargs = [], name::String)
910
func = method(args...; kwargs..., extrapolation_left = ExtrapolationType.Extension,
@@ -213,3 +214,22 @@ end
213214
@test_throws DataInterpolations.IntegralNotFoundError integral(A, 1.0, 100.0)
214215
@test_throws DataInterpolations.IntegralNotFoundError integral(A, 50.0)
215216
end
217+
218+
# issue #385
219+
@testset "Integrals with unitful numbers" begin
220+
u = rand(5)u"m"
221+
A = ConstantInterpolation(u, (1:5)u"s")
222+
@test @inferred(integral(A, 4u"s")) sum(u[1:3]) * u"s"
223+
end
224+
225+
@testset "cumulative_integral" begin
226+
A = ConstantInterpolation(["A", "B", "C"], [0.0, 0.25, 0.75])
227+
for cache_parameter in (true, false)
228+
@test @inferred(DataInterpolations.cumulative_integral(A, cache_parameter)) ===
229+
nothing
230+
end
231+
232+
A = ConstantInterpolation([3.1, 2.5, 4.7], [0.0, 0.25, 0.75])
233+
@test @inferred(DataInterpolations.cumulative_integral(A, false)) == Float64[]
234+
@test @inferred(DataInterpolations.cumulative_integral(A, true)) == [0.775, 2.025]
235+
end

0 commit comments

Comments
 (0)