Skip to content

Commit a561a86

Browse files
Add optimized implementations for ZeroKernel and ConstantKernel (#439)
* Add optimized implementations for ZeroKernel and ConstantKernel * Bump version * Fix format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 7d74366 commit a561a86

File tree

2 files changed

+66
-6
lines changed

2 files changed

+66
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.32"
3+
version = "0.10.33"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/basekernels/constant.jl

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,43 @@ See also: [`ConstantKernel`](@ref)
1515
"""
1616
struct ZeroKernel <: SimpleKernel end
1717

18-
kappa::ZeroKernel, d::T) where {T<:Real} = zero(T)
19-
18+
# SimpleKernel interface
19+
kappa(::ZeroKernel, ::Real) = false
2020
metric(::ZeroKernel) = Delta()
2121

22+
# Optimizations
23+
(::ZeroKernel)(x, y) = false
24+
kernelmatrix(::ZeroKernel, x::AbstractVector) = Falses(length(x), length(x))
25+
function kernelmatrix(::ZeroKernel, x::AbstractVector, y::AbstractVector)
26+
validate_inputs(x, y)
27+
return Falses(length(x), length(y))
28+
end
29+
function kernelmatrix!(K::AbstractMatrix, ::ZeroKernel, x::AbstractVector)
30+
validate_inplace_dims(K, x)
31+
return fill!(K, zero(eltype(K)))
32+
end
33+
function kernelmatrix!(
34+
K::AbstractMatrix, ::ZeroKernel, x::AbstractVector, y::AbstractVector
35+
)
36+
validate_inplace_dims(K, x, y)
37+
return fill!(K, zero(eltype(K)))
38+
end
39+
kernelmatrix_diag(::ZeroKernel, x::AbstractVector) = Falses(length(x))
40+
function kernelmatrix_diag(::ZeroKernel, x::AbstractVector, y::AbstractVector)
41+
validate_inputs(x, y)
42+
return Falses(length(x))
43+
end
44+
function kernelmatrix_diag!(K::AbstractVector, ::ZeroKernel, x::AbstractVector)
45+
validate_inplace_dims(K, x)
46+
return fill!(K, zero(eltype(K)))
47+
end
48+
function kernelmatrix_diag!(
49+
K::AbstractVector, ::ZeroKernel, x::AbstractVector, y::AbstractVector
50+
)
51+
validate_inplace_dims(K, x, y)
52+
return fill!(K, zero(eltype(K)))
53+
end
54+
2255
Base.show(io::IO, ::ZeroKernel) = print(io, "Zero Kernel")
2356

2457
"""
@@ -73,14 +106,41 @@ end
73106

74107
@functor ConstantKernel
75108

76-
kappa::ConstantKernel, x::Real) = only.c) * one(x)
77-
109+
# SimpleKernel interface
110+
kappa::ConstantKernel, ::Real) = only.c)
78111
metric(::ConstantKernel) = Delta()
79112

113+
# Optimizations
114+
(k::ConstantKernel)(x, y) = only(k.c)
80115
kernelmatrix(k::ConstantKernel, x::AbstractVector) = Fill(only(k.c), length(x), length(x))
81-
82116
function kernelmatrix(k::ConstantKernel, x::AbstractVector, y::AbstractVector)
117+
validate_inputs(x, y)
83118
return Fill(only(k.c), length(x), length(y))
84119
end
120+
function kernelmatrix!(K::AbstractMatrix, k::ConstantKernel, x::AbstractVector)
121+
validate_inplace_dims(K, x)
122+
return fill!(K, only(k.c))
123+
end
124+
function kernelmatrix!(
125+
K::AbstractMatrix, k::ConstantKernel, x::AbstractVector, y::AbstractVector
126+
)
127+
validate_inplace_dims(K, x, y)
128+
return fill!(K, only(k.c))
129+
end
130+
kernelmatrix_diag(k::ConstantKernel, x::AbstractVector) = Fill(only(k.c), length(x))
131+
function kernelmatrix_diag(k::ConstantKernel, x::AbstractVector, y::AbstractVector)
132+
validate_inputs(x, y)
133+
return Fill(only(k.c), length(x))
134+
end
135+
function kernelmatrix_diag!(K::AbstractVector, k::ConstantKernel, x::AbstractVector)
136+
validate_inplace_dims(K, x)
137+
return fill!(K, only(k.c))
138+
end
139+
function kernelmatrix_diag!(
140+
K::AbstractVector, k::ConstantKernel, x::AbstractVector, y::AbstractVector
141+
)
142+
validate_inplace_dims(K, x, y)
143+
return fill!(K, only(k.c))
144+
end
85145

86146
Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", only.c), ")")

0 commit comments

Comments
 (0)