Skip to content

Commit 1cb30aa

Browse files
authored
Add Spline{-1} support (#168)
* Add Spline{-1} support * add tests
1 parent c42fedf commit 1cb30aa

File tree

3 files changed

+75
-4
lines changed

3 files changed

+75
-4
lines changed

src/bases/splines.jl

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,41 @@ function getindex(B::HeavisideSpline{T}, x::Number, k::Int) where T
3939
x axes(B,1) && 1 k  size(B,2)|| throw(BoundsError())
4040

4141
p = B.points
42-
n = length(p)
43-
4442
p[k] < x < p[k+1] && return one(T)
4543
p[k] == x && return one(T)/2
4644
p[k+1] == x && return one(T)/2
4745
return zero(T)
4846
end
4947

48+
function getindex(B::Spline{-1,T}, x::Number, k::Int) where T
49+
x axes(B,1) && 1 k size(B,2)|| throw(BoundsError())
50+
51+
p = B.points
52+
p[k+1] == x && return convert(T,Inf)
53+
zero(T)
54+
end
55+
56+
57+
58+
grid(L::HeavisideSpline, n...) = L.points[1:end-1] .+ diff(L.points)/2
59+
plotgrid(L::HeavisideSpline, n...) = [L.points'; L.points'][2:end-1]
60+
function plotgridvalues(f::ApplyQuasiVector{<:Any,typeof(*),<:Tuple{HeavisideSpline,Any}})
61+
g = plotgrid(basis(f))
62+
c = coefficients(f)
63+
g,vec([c'; c'])
64+
end
65+
66+
function plotgrid(L::Spline{-1}, n...)
67+
p = L.points[2:end-1]
68+
vec([p'; p'; p'])
69+
end
70+
function plotgridvalues(f::ApplyQuasiVector{<:Any,typeof(*),<:Tuple{Spline{-1},Any}})
71+
g = plotgrid(basis(f))
72+
c = coefficients(f)
73+
g,vec([zeros(1,length(c)); c'; fill(NaN,1,length(c))])
74+
end
75+
76+
5077
# Splines sample same number of points regardless of length.
5178
grid(L::HeavisideSpline, ::Integer) = L.points[1:end-1] .+ diff(L.points)/2
5279
grid(L::LinearSpline, ::Integer) = L.points
@@ -88,6 +115,17 @@ function diff(L::LinearSpline{T}; dims::Integer=1) where T
88115
ApplyQuasiMatrix(*, HeavisideSpline{T}(x), D)
89116
end
90117

118+
function diff(L::HeavisideSpline{T}; dims::Integer=1) where T
119+
dims == 1 || error("not implemented")
120+
n = size(L,2)
121+
x = L.points
122+
D = BandedMatrix{T}(undef, (n-1,n), (0,1))
123+
d = diff(x)
124+
D[band(0)] .= -one(T)
125+
D[band(1)] .= one(T)
126+
ApplyQuasiMatrix(*, Spline{-1,T}(x), D)
127+
end
128+
91129

92130
##
93131
# sum
@@ -99,6 +137,7 @@ function _sum(A::HeavisideSpline, dims)
99137
end
100138

101139
function _sum(P::LinearSpline, dims)
140+
dims == 1 || error("not implemented")
102141
d = diff(P.points)
103142
ret = Array{float(eltype(d))}(undef, length(d)+1)
104143
ret[1] = d[1]/2
@@ -109,4 +148,10 @@ function _sum(P::LinearSpline, dims)
109148
permutedims(ret)
110149
end
111150

112-
_cumsum(H::HeavisideSpline{T}, dims) where T = LinearSpline(H.points) * tril(Ones{T}(length(H.points),length(H.points)-1) .* diff(H.points)',-1)
151+
function _sum(P::Spline{-1,T}, dims) where T
152+
dims == 1 || error("not implemented")
153+
Ones{T}(1, size(P,2))
154+
end
155+
156+
_cumsum(H::HeavisideSpline{T}, dims) where T = LinearSpline(H.points) * tril(Ones{T}(length(H.points),length(H.points)-1) .* diff(H.points)',-1)
157+
_cumsum(S::Spline{-1,T}, dims) where T = HeavisideSpline(S.points) * tril(ones(T,length(S.points)-1,length(S.points)-2),-1)

test/runtests.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using ContinuumArrays, QuasiArrays, IntervalSets, DomainSets, FillArrays, LinearAlgebra, BandedMatrices, InfiniteArrays, Test, Base64
22
import ContinuumArrays: ℵ₁, materialize, AffineQuasiVector, BasisLayout, AdjointBasisLayout, SubBasisLayout, ℵ₁,
33
MappedBasisLayout, AdjointMappedBasisLayouts, MappedWeightedBasisLayout, TransformFactorization, Weight, WeightedBasisLayout, SubWeightedBasisLayout, WeightLayout,
4-
basis, invmap, Map, checkpoints, plotgrid, plotgrid_layout, mul, plotvalues
4+
basis, invmap, Map, checkpoints, plotgrid, plotgrid_layout, mul, plotvalues, plotgridvalues
55
import QuasiArrays: SubQuasiArray, MulQuasiMatrix, Vec, Inclusion, QuasiDiagonal, LazyQuasiArrayApplyStyle, LazyQuasiArrayStyle
66
import LazyArrays: MemoryLayout, ApplyStyle, Applied, colsupport, arguments, ApplyLayout, LdivStyle, MulStyle
77

@@ -94,6 +94,14 @@ include("test_basisconcat.jl")
9494
a = affine(0..1, 1..5)
9595
v = L[a,:] * c
9696
@test plotvalues(v) == v[plotgrid(v)]
97+
98+
H = HeavisideSpline(1:5)
99+
u = H * (2:5)
100+
x,v = plotgridvalues(u)
101+
@test u[[1+4eps(),2-4eps(),2+4eps(),3-4eps(),3+4eps(),4-4eps(),4+4eps(),5-4eps()]] v
102+
103+
x,v = plotgridvalues(diff(u))
104+
@test x == [2,2,2,3,3,3,4,4,4]
97105
end
98106

99107
include("test_recipesbaseext.jl")

test/test_splines.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,4 +569,22 @@ import ContinuumArrays: basis, AdjointBasisLayout, ExpansionLayout, BasisLayout,
569569
@test Pl*(exp.(x .+ y')) plan_transform(L, Block(1,1), 2) * (plan_transform(L, Block(1,1), 1) * exp.(x .+ y'))
570570
end
571571
end
572+
573+
@testset "Dirac" begin
574+
H = HeavisideSpline(0:5)
575+
S = Spline{-1}(0:5)
576+
@test iszero(S[0.1,1])
577+
@test iszero(S[0.1,1:4])
578+
@test isinf(S[1,1])
579+
@test iszero(S[1,2])
580+
@test iszero(S[0,:])
581+
@test_throws BoundsError S[0.1,0]
582+
@test_throws BoundsError S[-1,1]
583+
584+
@test S \ diff(H) == diagm(0 => fill(-1,4), 1 => fill(1, 4))[1:end-1,:]
585+
586+
u = S * (1:4)
587+
@test sum(u) == 10
588+
@test cumsum(u)[5-4eps()] == 10
589+
end
572590
end

0 commit comments

Comments
 (0)