Skip to content

Commit b983363

Browse files
Merge pull request #394 from SciML/dw/get_u
Do not fall back to inplace interpolation method
2 parents 8ead701 + e3cfc01 commit b983363

File tree

3 files changed

+68
-44
lines changed

3 files changed

+68
-44
lines changed

src/DataInterpolations.jl

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -25,37 +25,29 @@ include("online.jl")
2525
include("show.jl")
2626

2727
(interp::AbstractInterpolation)(t::Number) = _interpolate(interp, t)
28-
2928
function (interp::AbstractInterpolation)(t::AbstractVector)
30-
u = get_u(interp.u, t)
31-
interp(u, t)
32-
end
33-
34-
function get_u(u::AbstractVector, t)
35-
return similar(t, promote_type(eltype(u), eltype(t)))
36-
end
37-
38-
function get_u(u::AbstractVector{<:AbstractVector}, t)
39-
type = promote_type(eltype(eltype(u)), eltype(t))
40-
return [zeros(type, length(first(u))) for _ in eachindex(t)]
41-
end
42-
43-
function get_u(u::AbstractMatrix, t)
44-
type = promote_type(eltype(u), eltype(t))
45-
return zeros(type, (size(u, 1), length(t)))
29+
if interp.u isa AbstractVector
30+
# Return a vector of interpolated values, on for each element in `t`
31+
return map(interp, t)
32+
elseif interp.u isa AbstractArray
33+
# Stack interpolated values if `u` was stored in matrix/... form
34+
return stack(interp, t)
35+
end
4636
end
4737

48-
function (interp::AbstractInterpolation)(u::AbstractMatrix, t::AbstractVector)
49-
@inbounds for i in eachindex(t)
50-
u[:, i] = interp(t[i])
38+
function (interp::AbstractInterpolation)(out::AbstractVector, t::AbstractVector)
39+
if length(out) != length(t)
40+
throw(DimensionMismatch("number of evaluation points and length of the result vector must be equal"))
5141
end
52-
u
42+
map!(interp, out, t)
43+
return out
5344
end
54-
function (interp::AbstractInterpolation)(u::AbstractVector, t::AbstractVector)
55-
@inbounds for i in eachindex(u, t)
56-
u[i] = interp(t[i])
45+
function (interp::AbstractInterpolation)(out::AbstractArray, t::AbstractVector)
46+
if size(out, ndims(out)) != length(t)
47+
throw(DimensionMismatch("number of evaluation points and last dimension of the result array must be equal"))
5748
end
58-
u
49+
map!(interp, eachslice(out; dims = ndims(out)), t)
50+
return out
5951
end
6052

6153
const EXTRAPOLATION_ERROR = "Cannot extrapolate as `extrapolate` keyword passed was `false`"

test/extrapolation_tests.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,16 @@ end
5252
# Left extrapolation
5353
A = ConstantInterpolation(u_un, t_un; extrapolation_left = extrapolation_type)
5454
t_eval = 0.0u"s"
55-
@test A(t_eval) == 1.0u"m"
55+
@test @inferred(A(t_eval)) == 1.0u"m"
56+
@test @inferred(A([t_eval])) == [1.0u"m"]
57+
@test A([t_eval]) isa Vector{typeof(1.0u"m")}
5658

5759
# Right extrapolation
5860
A = ConstantInterpolation(u_un, t_un; extrapolation_right = extrapolation_type)
5961
t_eval = 3.0u"s"
60-
@test A(t_eval) == 2.0u"m"
62+
@test @inferred(A(t_eval)) == 2.0u"m"
63+
@test @inferred(A([t_eval])) == [2.0u"m"]
64+
@test A([t_eval]) isa Vector{typeof(2.0u"m")}
6165
end
6266
end
6367

@@ -68,22 +72,30 @@ end
6872
# Left constant extrapolation
6973
A = LinearInterpolation(u_un, t_un; extrapolation_left = ExtrapolationType.Constant)
7074
t_eval = 0.0u"s"
71-
@test A(t_eval) == 1.0u"m"
75+
@test @inferred(A(t_eval)) == 1.0u"m"
76+
@test @inferred(A([t_eval])) == [1.0u"m"]
77+
@test A([t_eval]) isa Vector{typeof(1.0u"m")}
7278

7379
# Right constant extrapolation
7480
A = LinearInterpolation(u_un, t_un; extrapolation_right = ExtrapolationType.Constant)
7581
t_eval = 3.0u"s"
76-
@test A(t_eval) == 2.0u"m"
82+
@test @inferred(A(t_eval)) == 2.0u"m"
83+
@test @inferred(A([t_eval])) == [2.0u"m"]
84+
@test A([t_eval]) isa Vector{typeof(2.0u"m")}
7785

7886
# Left linear extrapolation
7987
A = LinearInterpolation(u_un, t_un; extrapolation_left = ExtrapolationType.Linear)
8088
t_eval = 0.0u"s"
81-
@test A(t_eval) == 0.0u"m"
89+
@test @inferred(A(t_eval)) == 0.0u"m"
90+
@test @inferred(A([t_eval])) == [0.0u"m"]
91+
@test A([t_eval]) isa Vector{typeof(0.0u"m")}
8292

8393
# Right constant extrapolation
8494
A = LinearInterpolation(u_un, t_un; extrapolation_right = ExtrapolationType.Linear)
8595
t_eval = 3.0u"s"
86-
@test A(t_eval) == 3.0u"m"
96+
@test @inferred(A(t_eval)) == 3.0u"m"
97+
@test @inferred(A([t_eval])) == [3.0u"m"]
98+
@test A([t_eval]) isa Vector{typeof(3.0u"m")}
8799
end
88100

89101
@testset "Linear Interpolation" begin

test/interpolation_tests.jl

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,13 @@ end
525525
itp = ConstantInterpolation([2], [0.0]; extrapolation = ExtrapolationType.Constant)
526526
@test itp(1.0) === 2
527527
@test itp(-1.0) === 2
528+
529+
# Test output type of vector evaluation (issue #388)
530+
u = [2, 3]
531+
t = [0.0, 1.0]
532+
itp = ConstantInterpolation(u, t)
533+
@test @inferred(itp(t)) == itp.(t)
534+
@test typeof(itp(t)) === typeof(itp.(t)) === Vector{Int}
528535
end
529536

530537
@testset "QuadraticSpline Interpolation" begin
@@ -855,33 +862,46 @@ end
855862

856863
@testset "Type of vector returned" begin
857864
# Issue https://github.com/SciML/DataInterpolations.jl/issues/253
858-
t1 = Float32[0.1, 0.2, 0.3, 0.4, 0.5]
859-
t2 = Float64[0.1, 0.2, 0.3, 0.4, 0.5]
860-
interps_and_types = [
861-
(LinearInterpolation(t1, t1), Float32),
862-
(LinearInterpolation(t1, t2), Float32),
863-
(LinearInterpolation(t2, t1), Float64),
864-
(LinearInterpolation(t2, t2), Float64)
865-
]
866-
for i in eachindex(interps_and_types)
867-
@test eltype(interps_and_types[i][1](t1)) == interps_and_types[i][2]
865+
ut1 = Float32[0.1, 0.2, 0.3, 0.4, 0.5]
866+
ut2 = Float64[0.1, 0.2, 0.3, 0.4, 0.5]
867+
for u in (ut1, ut2), t in (ut1, ut2)
868+
interp = LinearInterpolation(ut1, ut2)
869+
for xs in (u, t)
870+
ys = @inferred(interp(xs))
871+
@test ys isa Vector{typeof(interp(first(xs)))}
872+
@test all(y == interp(x) for (x, y) in zip(xs, ys))
873+
end
868874
end
869875
end
870876

871877
@testset "Plugging vector timepoints" begin
872878
# Issue https://github.com/SciML/DataInterpolations.jl/issues/267
873879
t = Float64[1.0, 2.0, 3.0, 4.0, 5.0]
880+
x = Float64[1.3, 2.2, 4.1]
874881
@testset "utype - Vectors" begin
875882
interp = LinearInterpolation(rand(5), t)
876-
@test interp(t) isa Vector{Float64}
883+
y = interp(x)
884+
@test y isa Vector{Float64}
885+
@test length(y) == 3
877886
end
878887
@testset "utype - Vector of Vectors" begin
879888
interp = LinearInterpolation([rand(2) for _ in 1:5], t)
880-
@test interp(t) isa Vector{Vector{Float64}}
889+
y = interp(x)
890+
@test y isa Vector{Vector{Float64}}
891+
@test length(y) == 3
892+
@test all(length(yi) == 2 for yi in y)
881893
end
882894
@testset "utype - Matrix" begin
883895
interp = LinearInterpolation(rand(2, 5), t)
884-
@test interp(t) isa Matrix{Float64}
896+
y = interp(x)
897+
@test y isa Matrix{Float64}
898+
@test size(y) == (2, 3)
899+
end
900+
@testset "utype - Array" begin
901+
interp = LinearInterpolation(rand(2, 3, 4, 5), t)
902+
y = interp(x)
903+
@test y isa Array{Float64, 4}
904+
@test size(y) == (2, 3, 4, 3)
885905
end
886906
end
887907

0 commit comments

Comments
 (0)