Skip to content

Commit 62749c6

Browse files
Merge pull request #247 from sathvikbhagavan/sb/sym_second_order
refactor!: remove indexing dispatches and add dispatch for higher order derivatives with Symbolics
2 parents b1f9851 + 40a6d96 commit 62749c6

10 files changed

+93
-97
lines changed

ext/DataInterpolationsOptimExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function Curvefit(u,
3131
mfit = optimize(od, lb, ub, p0, Fminbox(alg))
3232
end
3333
pmin = Optim.minimizer(mfit)
34-
CurvefitCache{true}(u, t, model, p0, ub, lb, alg, pmin, extrapolate)
34+
CurvefitCache(u, t, model, p0, ub, lb, alg, pmin, extrapolate)
3535
end
3636

3737
# Curvefit

ext/DataInterpolationsRegularizationToolsExt.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Abstrac
7575
Wls½ = LA.diagm(sqrt.(wls))
7676
Wr½ = LA.diagm(sqrt.(wr))
7777
û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
78-
RegularizationSmooth{true}(u, û, t, t̂, wls, wr, d, λ, alg, Aitp, extrapolate)
78+
RegularizationSmooth(u, û, t, t̂, wls, wr, d, λ, alg, Aitp, extrapolate)
7979
end
8080
"""
8181
Direct smoothing, no `t̂` or weights
@@ -94,7 +94,7 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, d::Int = 2;
9494
Wls½ = Array{Float64}(LA.I, N, N)
9595
Wr½ = Array{Float64}(LA.I, N - d, N - d)
9696
û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
97-
RegularizationSmooth{true}(u,
97+
RegularizationSmooth(u,
9898
û,
9999
t,
100100
t̂,
@@ -121,7 +121,7 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Abstrac
121121
Wls½ = Array{Float64}(LA.I, N, N)
122122
Wr½ = Array{Float64}(LA.I, N̂ - d, N̂ - d)
123123
û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
124-
RegularizationSmooth{true}(u,
124+
RegularizationSmooth(u,
125125
û,
126126
t,
127127
t̂,
@@ -149,7 +149,7 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Abstrac
149149
Wls½ = LA.diagm(sqrt.(wls))
150150
Wr½ = Array{Float64}(LA.I, N̂ - d, N̂ - d)
151151
û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
152-
RegularizationSmooth{true}(u,
152+
RegularizationSmooth(u,
153153
û,
154154
t,
155155
t̂,
@@ -179,7 +179,7 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing
179179
Wls½ = LA.diagm(sqrt.(wls))
180180
Wr½ = Array{Float64}(LA.I, N - d, N - d)
181181
û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
182-
RegularizationSmooth{true}(u,
182+
RegularizationSmooth(u,
183183
û,
184184
t,
185185
t̂,
@@ -209,7 +209,7 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing
209209
Wls½ = LA.diagm(sqrt.(wls))
210210
Wr½ = LA.diagm(sqrt.(wr))
211211
û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
212-
RegularizationSmooth{true}(u,
212+
RegularizationSmooth(u,
213213
û,
214214
t,
215215
t̂,
@@ -240,7 +240,7 @@ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing
240240
Wls½ = LA.diagm(sqrt.(wls))
241241
Wr½ = LA.diagm(sqrt.(wr))
242242
û, λ, Aitp = _reg_smooth_solve(u, t̂, d, M, Wls½, Wr½, λ, alg, extrapolate)
243-
RegularizationSmooth{true}(u,
243+
RegularizationSmooth(u,
244244
û,
245245
t,
246246
t̂,

ext/DataInterpolationsSymbolicsExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ function derivative(interp::AbstractInterpolation, t::Num, order = 1)
2121
end
2222
SymbolicUtils.promote_symtype(::typeof(derivative), _...) = Real
2323

24+
function Symbolics.derivative(::typeof(derivative), args::NTuple{3, Any}, ::Val{2})
25+
Symbolics.unwrap(derivative(args[1], Symbolics.wrap(args[2]), args[3] + 1))
26+
end
27+
2428
function Symbolics.derivative(interp::AbstractInterpolation, args::NTuple{1, Any}, ::Val{1})
2529
Symbolics.unwrap(derivative(interp, Symbolics.wrap(args[1])))
2630
end

src/DataInterpolations.jl

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,7 @@ module DataInterpolations
22

33
### Interface Functionality
44

5-
abstract type AbstractInterpolation{FT, T} <: AbstractVector{T} end
6-
7-
Base.size(A::AbstractInterpolation) = size(A.u)
8-
Base.size(A::AbstractInterpolation{true}) = length(A.u) .+ size(A.t)
9-
Base.getindex(A::AbstractInterpolation, i) = A.u[i]
10-
function Base.getindex(A::AbstractInterpolation{true}, i)
11-
i <= length(A.u) ? A.u[i] : A.t[i - length(A.u)]
12-
end
13-
Base.setindex!(A::AbstractInterpolation, x, i) = A.u[i] = x
14-
function Base.setindex!(A::AbstractInterpolation{true}, x, i)
15-
i <= length(A.u) ? (A.u[i] = x) : (A.t[i - length(A.u)] = x)
16-
end
5+
abstract type AbstractInterpolation{T} end
176

187
using LinearAlgebra, RecipesBase
198
using PrettyTables
@@ -67,7 +56,7 @@ export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation,
6756

6857
# added for RegularizationSmooth, JJS 11/27/21
6958
### Regularization data smoothing and interpolation
70-
struct RegularizationSmooth{uType, tType, FT, T, T2} <: AbstractInterpolation{FT, T}
59+
struct RegularizationSmooth{uType, tType, T, T2} <: AbstractInterpolation{T}
7160
u::uType
7261
::uType
7362
t::tType
@@ -77,9 +66,9 @@ struct RegularizationSmooth{uType, tType, FT, T, T2} <: AbstractInterpolation{FT
7766
d::Int # derivative degree used to calculate the roughness
7867
λ::T2 # regularization parameter
7968
alg::Symbol # how to determine λ: `:fixed`, `:gcv_svd`, `:gcv_tr`, `L_curve`
80-
Aitp::AbstractInterpolation{FT, T}
69+
Aitp::AbstractInterpolation{T}
8170
extrapolate::Bool
82-
function RegularizationSmooth{FT}(u,
71+
function RegularizationSmooth(u,
8372
û,
8473
t,
8574
t̂,
@@ -89,8 +78,8 @@ struct RegularizationSmooth{uType, tType, FT, T, T2} <: AbstractInterpolation{FT
8978
λ,
9079
alg,
9180
Aitp,
92-
extrapolate) where {FT}
93-
new{typeof(u), typeof(t), FT, eltype(u), typeof(λ)}(u,
81+
extrapolate)
82+
new{typeof(u), typeof(t), eltype(u), typeof(λ)}(u,
9483
û,
9584
t,
9685
t̂,
@@ -116,9 +105,8 @@ struct CurvefitCache{
116105
lbType,
117106
algType,
118107
pminType,
119-
FT,
120108
T
121-
} <: AbstractInterpolation{FT, T}
109+
} <: AbstractInterpolation{T}
122110
u::uType
123111
t::tType
124112
m::mType # model type
@@ -128,10 +116,10 @@ struct CurvefitCache{
128116
alg::algType # alg to optimize cost function
129117
pmin::pminType # optimized params
130118
extrapolate::Bool
131-
function CurvefitCache{FT}(u, t, m, p0, ub, lb, alg, pmin, extrapolate) where {FT}
119+
function CurvefitCache(u, t, m, p0, ub, lb, alg, pmin, extrapolate)
132120
new{typeof(u), typeof(t), typeof(m),
133121
typeof(p0), typeof(ub), typeof(lb),
134-
typeof(alg), typeof(pmin), FT, eltype(u)}(u,
122+
typeof(alg), typeof(pmin), eltype(u)}(u,
135123
t,
136124
m,
137125
p0,
@@ -150,7 +138,4 @@ end
150138

151139
export Curvefit
152140

153-
# Deprecated April 2020
154-
export ZeroSpline
155-
156141
end # module

src/integrals.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function integral(A::AbstractInterpolation, t1::Number, t2::Number)
1414
if A.t[idx2] == t2
1515
idx2 -= 1
1616
end
17-
total = zero(eltype(A))
17+
total = zero(eltype(A.u))
1818
for idx in idx1:idx2
1919
lt1 = idx == idx1 ? t1 : A.t[idx]
2020
lt2 = idx == idx2 ? t2 : A.t[idx + 1]

0 commit comments

Comments
 (0)