Commit 8f8641f (1 parent: e551682)

edits for product and sum kernels: support for custom kernels, and tests with input_trait and constant constituents

8 files changed: 76 additions, 41 deletions

README.md (9 additions, 1 deletion)

````diff
@@ -94,8 +94,16 @@ custom_rbf(x, y) = exp(-sum(abs2, x .- y)) # custom RBF implementation
 ```
 To take advantage of some specialized structure-aware algorithms, it is prudent to let CovarianceFunctions.jl know about the input type, in this case
 ```julia
-input_trait(::typeof(custom_rbf)) = IsotropicInput()
+CovarianceFunctions.input_trait(::typeof(custom_rbf)) = IsotropicInput()
 ```
+Other possible options include `DotProductInput` or `StationaryLinearFunctionalInput`.
+To enable efficient output type inference for custom kernels with parameters,
+extend `Base.eltype`.
+Since the custom kernel above does not have any parameters, we set the type to the bottom type `Union{}`:
+```julia
+Base.eltype(k::typeof(custom_rbf)) = Union{}
+```
+The type of the output of the kernel `k` with inputs `x` and `y` is then expected to be `promote_type(eltype.((k, x, y))...)`.

 ## Toeplitz Structure
````
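For a kernel that does carry parameters, the README's promotion rule implies `eltype` should return the parameter's type. A minimal sketch under that assumption; `ScaledRBF` is a hypothetical kernel for illustration, not part of this commit:

```julia
# hypothetical parametric kernel, illustrating the eltype convention only
struct ScaledRBF{T<:Real}
    ℓ::T # lengthscale parameter
end
(k::ScaledRBF)(x, y) = exp(-sum(abs2, x .- y) / k.ℓ^2)

CovarianceFunctions.input_trait(::ScaledRBF) = CovarianceFunctions.IsotropicInput()
Base.eltype(k::ScaledRBF{T}) where {T} = T # parameter type drives output promotion

# e.g. Float32 inputs with a Float64 lengthscale promote to Float64:
# promote_type(eltype(ScaledRBF(1.0)), Float32, Float32) == Float64
```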

src/algebra.jl (29 additions, 19 deletions)

```diff
@@ -1,19 +1,22 @@
 ############################# kernel algebra ###################################
-# IDEA: separable sum gramian
-# IDEA: (Separable) Sum and Product could be one definition with meta programming
+# NOTE: output type inference of product, sum, and power not supported for
+# user-defined kernels unless Base.eltype is defined for them
 ################################ Product #######################################
-# TODO: constructors which merge products and sums
-struct Product{T, AT<:Tuple{Vararg{AbstractKernel}}} <: AbstractKernel{T}
+# IDEA: constructors which merge products and sums
+struct Product{T, AT<:Union{Tuple, AbstractVector}} <: AbstractKernel{T}
     args::AT
-    function Product(k::Tuple{Vararg{AbstractKernel}})
-        T = promote_type(eltype.(k)...)
-        new{T, typeof(k)}(k)
-    end
+    # input_traits # could keep track of input_trait.(args)
+    # input_trait # could keep track of the overall input trait
 end
 @functor Product
-(P::Product)(τ) = prod(k->k(τ), P.args) # TODO could check for isotropy here
+function Product(k::Union{Tuple, AbstractVector})
+    T = promote_type(eltype.(k)...)
+    Product{T, typeof(k)}(k)
+end
+Product(k...) = Product(k)
+(P::Product)(τ) = prod(k->k(τ), P.args) # IDEA could check for isotropy here
 (P::Product)(x, y) = prod(k->k(x, y), P.args)
-# (P::Product)(x, y) = isstationary(P) ? P(difference(x, y)) : prod(k->k(x, y), P.args)
+# (P::Product)(x, y) = isisotropic(P) ? P(difference(x, y)) : prod(k->k(x, y), P.args)
 Product(k::AbstractKernel...) = Product(k)
 Product(k::AbstractVector{<:AbstractKernel}) = Product(k...)
 Base.prod(k::AbstractVector{<:AbstractKernel}) = Product(k)
@@ -23,32 +26,38 @@ Base.:*(c::Number, k::AbstractKernel) = Constant(c) * k
 Base.:*(k::AbstractKernel, c::Number) = Constant(c) * k

 ################################### Sum ########################################
-struct Sum{T, AT<:Tuple{Vararg{AbstractKernel}}} <: AbstractKernel{T}
+struct Sum{T, AT<:Union{Tuple, AbstractVector}} <: AbstractKernel{T}
     args::AT
-    function Sum(k::Tuple{Vararg{AbstractKernel}})
-        T = promote_type(eltype.(k)...)
-        new{T, typeof(k)}(k)
-    end
+    # input_trait # could keep track of the overall input trait
 end
 @functor Sum
+function Sum(k::Union{Tuple, AbstractVector})
+    T = promote_type(eltype.(k)...)
+    Sum{T, typeof(k)}(k)
+end
+Sum(k...) = Sum(k)
 (S::Sum)(τ) = sum(k->k(τ), S.args) # should only be called if S is stationary
 (S::Sum)(x, y) = sum(k->k(x, y), S.args)
 # (S::Sum)(τ) = isstationary(S) ? sum(k->k(τ), S.args) : error("One argument evaluation not possible for non-stationary kernel")
 # (S::Sum)(x, y) = isstationary(S) ? S(difference(x, y)) : sum(k->k(x, y), S.args)
-Sum(k::AbstractKernel...) = Sum(k)
-Sum(k::AbstractVector{<:AbstractKernel}) = Sum(k...)
+Sum(k...) = Sum(k)
 Base.sum(k::AbstractVector{<:AbstractKernel}) = Sum(k)

 Base.:+(k::AbstractKernel...) = Sum(k)
 Base.:+(k::AbstractKernel, c::Number) = k + Constant(c)
 Base.:+(c::Number, k::AbstractKernel) = k + Constant(c)

 ################################## Power #######################################
-struct Power{T, K<:AbstractKernel{T}, PT} <: AbstractKernel{T}
+struct Power{T, K<:AbstractKernel} <: AbstractKernel{T}
     k::K
-    p::PT
+    p::Int
+    # input_trait # could keep track of the overall input trait
 end
 @functor Power
+function Power(k, p::Int)
+    T = promote_type(eltype(k))
+    Power{T, typeof(k)}(k, p)
+end
 (P::Power)(τ) = P.k(τ)^P.p
 (P::Power)(x, y) = P.k(x, y)^P.p
 Base.:^(k::AbstractKernel, p::Number) = Power(k, p)
@@ -57,6 +66,7 @@ Base.:^(k::AbstractKernel, p::Number) = Power(k, p)
 # product kernel, but separately evaluates component kernels on different parts of the input
 struct SeparableProduct{T, K} <: AbstractKernel{T}
     args::K # kernel for input covariances
+    # input_trait # could keep track of the overall input trait
 end
 @functor SeparableProduct
 SeparableProduct(k...) = SeparableProduct(k)
```
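The switch from `Tuple{Vararg{AbstractKernel}}` to `Union{Tuple, AbstractVector}` is what lets plain functions participate in kernel algebra; the constructors then rely on `Base.eltype` being defined for each constituent, per the NOTE above. A standalone sketch of the promotion logic these constructors use, outside the package:

```julia
# standalone illustration of the output-type promotion in the constructors above
custom_rbf(x, y) = exp(-sum(abs2, x .- y)) # a plain function acting as a kernel
Base.eltype(::typeof(custom_rbf)) = Union{} # parameter-free: the bottom type

ks = (custom_rbf, custom_rbf)    # tuples of non-AbstractKernel callables are now allowed
T = promote_type(eltype.(ks)...) # Union{}, so the output type is driven entirely by the inputs
```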

src/gradient.jl (12 additions, 0 deletions)

```diff
@@ -442,6 +442,18 @@ function value_gradient_kernel(k, x, y, T::InputTrait = input_trait(G.k))
     value_gradient_kernel!(K, k, x, y, T)
 end

+# IDEA: specialize first_gradient!(g, k, x, y) = ForwardDiff.gradient!(g, z->k(z, y), x)
+# computes covariance between value and gradient
+# function value_gradient_covariance!(gx, gy, k, x, y, ::IsotropicInput)
+#     r² = sum(abs2, difference(x, y))
+#     g .= derivative(k, r²)
+# end
+#
+# function value_gradient_covariance!(gx, gy, k, x, y, ::GenericInput())
+#     r² = sum(abs2, difference(x, y))
+#     g .= derivative(k, r²)
+# end
+
 # IDEA: specialize evaluate for IsotropicInput, DotProductInput
 # returns block matrix
 function value_gradient_kernel!(K::DerivativeKernelElement, k, x, y, T::InputTrait)
```
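The commented-out stub above is incomplete: it writes to an undefined `g` and repeats the isotropic body for `GenericInput`. For the isotropic case, k(x, y) = f(r²) with r² = |x − y|², the chain rule gives ∇ₓk = 2 f′(r²)(x − y) and ∇ᵧk = −2 f′(r²)(x − y). A sketch of how the stub might be completed, assuming a scalar helper `derivative(k, r²)` that returns f′(r²) (its exact form is not shown in the commit):

```julia
# sketch only: one way the commented-out stub could be filled in for isotropic kernels
function value_gradient_covariance!(gx, gy, k, x, y, ::IsotropicInput)
    d = x - y              # difference vector
    r² = sum(abs2, d)
    f′ = derivative(k, r²) # assumed scalar derivative of the radial profile f
    @. gx = 2f′ * d        # ∇ₓ k(x, y) = 2 f′(r²) (x - y) by the chain rule
    @. gy = -2f′ * d       # ∇ᵧ k(x, y) = -∇ₓ k(x, y) for isotropic kernels
    return gx, gy
end
```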

src/gradient_algebra.jl (4 additions, 9 deletions)

```diff
@@ -49,30 +49,25 @@ function gradient_kernel!(W::Woodbury, k::Product, x::AbstractVector, y::Abstrac
     # k_vec(x, y) = [h(x, y) for h in k.args] # include in loop
     # ForwardDiff.jacobian!(W.U', z->k_vec(z, y), x) # this is actually less allocating than the gradient! option
     # ForwardDiff.jacobian!(W.V, z->k_vec(x, z), y)
-
+    # GradientConfig() # for generic version, this could be pre-computed for efficiency gains
     r = length(k.args)
     for i in 1:r # parallelize this?
         h, H = k.args[i], A.args[i]
         hxy = h(x, y)
         D = H.args[1]
         D.diag .= prod_k_j / hxy
+        # input_trait(h) could be pre-computed, or should not be passed here, because the factors might be composite kernels themselves
         H.args[2] = gradient_kernel!(H.args[2], h, x, y, input_trait(h))

         ui, vi = @views W.U[:, i], W.V[i, :]
-        ForwardDiff.gradient!(ui, z->h(z, y), x)
-        ForwardDiff.gradient!(vi, z->h(x, z), y) # these are bottlenecks
+        ForwardDiff.gradient!(ui, z->h(z, y), x) # these are bottlenecks
+        ForwardDiff.gradient!(vi, z->h(x, z), y) # TODO: replace by value_gradient_covariance!
         @. ui *= prod_k_j / hxy
         @. vi /= hxy
     end
     return W
 end

-# IDEA: specialize first_gradient!(g, k, x, y) = ForwardDiff.gradient!(g, z->k(z, y), x)
-# function first_gradient!(g, k, x, y, ::IsotropicInput)
-#     r² = sum(abs2, difference(x, y))
-#     g .= derivative(k, r²)
-# end
-
 ############################# Separable Product ################################
 # for product kernel with generic input
 function allocate_gradient_kernel(k::SeparableProduct, x::AbstractVector{<:Number},
```
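The loop encodes the product rule for mixed second derivatives: for k = ∏ᵢ hᵢ, ∂ₓ∂ᵧᵀk = Σᵢ (∏_{j≠i} hⱼ) ∂ₓ∂ᵧᵀhᵢ plus cross terms (∂ₓhᵢ)(∂ᵧhⱼ)ᵀ for i ≠ j, which is exactly a low-rank (Woodbury) correction with one `U` column and `V` row per factor. A standalone check of the two-factor identity using only ForwardDiff, not the package internals:

```julia
using ForwardDiff, LinearAlgebra

h1(x, y) = exp(-sum(abs2, x - y))  # two example factors
h2(x, y) = (1 + dot(x, y))^2
k(x, y) = h1(x, y) * h2(x, y)

x, y = randn(3), randn(3)
# mixed second derivative ∂ₓ∂ᵧᵀ of the product, computed directly
H  = ForwardDiff.jacobian(yy -> ForwardDiff.gradient(xx -> k(xx, yy), x), y)

# product rule: factor-wise second-derivative terms plus a rank-2 cross term
H1 = ForwardDiff.jacobian(yy -> ForwardDiff.gradient(xx -> h1(xx, yy), x), y)
H2 = ForwardDiff.jacobian(yy -> ForwardDiff.gradient(xx -> h2(xx, yy), x), y)
g1x = ForwardDiff.gradient(xx -> h1(xx, y), x); g1y = ForwardDiff.gradient(yy -> h1(x, yy), y)
g2x = ForwardDiff.gradient(xx -> h2(xx, y), x); g2y = ForwardDiff.gradient(yy -> h2(x, yy), y)
H_rule = h2(x, y) * H1 + h1(x, y) * H2 + g1x * g2y' + g2x * g1y'

@assert H ≈ H_rule
```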

src/properties.jl (1 addition, 1 deletion)

```diff
@@ -51,7 +51,7 @@ function input_trait(S::ProductsAndSums)
     trait = input_trait(S.args[i]) # first non-constant kernel
     for j in i+1:length(S.args)
         k = S.args[j]
-        if k isa Constant
+        if k isa Constant # ignore constants, since they can function as any input type
             continue
         elseif input_trait(k) != trait # if the non-constant kernels don't have the same input type,
             return GenericInput() # we default back to GenericInput
```
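In words: constants are compatible with every trait, the first non-constant argument sets the candidate trait, and any disagreeing non-constant argument forces the `GenericInput` fallback. A condensed standalone restatement of that control flow (a sketch; the enclosing method and its all-constant edge case are not fully shown in this hunk):

```julia
# condensed sketch of the skip-constants logic above, decoupled from the package's
# types: `traits` holds input_trait.(args) and `isconst` flags Constant kernels
function combined_trait(traits, isconst, generic)
    i = findfirst(x -> !x, isconst)  # first non-constant argument
    i === nothing && return generic  # assumption: the all-constant case is handled upstream
    trait = traits[i]                # candidate trait from the first non-constant kernel
    for j in i+1:length(traits)
        isconst[j] && continue       # constants can function as any input type
        traits[j] == trait || return generic # disagreement forces the generic fallback
    end
    return trait
end

combined_trait([:iso, :iso], [false, true], :generic) # -> :iso, the constant is skipped
```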

src/stationary.jl (5 additions, 6 deletions)

```diff
@@ -12,24 +12,22 @@
 ############################# constant kernel ##################################
 # can be used to rescale existing kernels
 # IDEA: Allow Matrix-valued constant
-struct ConstantKernel{T} <: IsotropicKernel{T}
+struct Constant{T} <: IsotropicKernel{T}
     c::T
-    function ConstantKernel(c, check::Bool = true)
+    function Constant(c, check::Bool = true)
         if check && !ispsd(c)
             throw(DomainError("Constant is not positive semi-definite: $c"))
         end
         new{typeof(c)}(c)
     end
 end
-@functor ConstantKernel
-const Constant = ConstantKernel
-
 # isisotropic(::Constant) = true
 # ismercer(k::Constant) = ispsd(k.c)
 # Constant(c) = Constant{typeof(c)}(c)

 # should type of constant field and r agree? what promotion is necessary?
 # do we need the isotropic/ stationary evaluation, if we overwrite the mercer one?
+(k::Constant)() = k.c
 (k::Constant)(r²) = k.c # stationary / isotropic
 (k::Constant)(x, y) = k.c # mercer
@@ -196,8 +194,9 @@
 struct CosineKernel{T, V<:Union{T, AbstractVector{T}}} <: StationaryKernel{T}
     c::V
 end
+@functor CosineKernel
 const Cosine = CosineKernel
-@functor Cosine
+const Cos = Cosine

 # IDEA: trig-identity -> low-rank gramian
 # NOTE: this is the only stationary non-isotropic kernel so far
```
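With the new nullary method, `Constant` answers all three calling conventions with its value, which is what lets it stand in for any input trait in sums and products. Illustrative usage:

```julia
c = Constant(2.0)
c() == c(0.3) == c(randn(3), randn(3)) == 2.0 # nullary, stationary/isotropic, and Mercer forms
```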

test/gradient.jl (2 additions, 2 deletions)

```diff
@@ -7,7 +7,7 @@ using CovarianceFunctions
 using CovarianceFunctions: EQ, RQ, Dot, ExponentialDot, NN, Matern, MaternP,
     Lengthscale, input_trait, GradientKernel, ValueGradientKernel, GradientKernelElement,
     DerivativeKernel, ValueDerivativeKernel, DerivativeKernelElement, Cosine,
-    Woodbury, LazyMatrixProduct, ConstantKernel
+    Woodbury, LazyMatrixProduct, Constant

 const AbstractMatOrFac = Union{AbstractMatrix, Factorization}
@@ -78,7 +78,7 @@ const AbstractMatOrFac = Union{AbstractMatrix, Factorization}
     @test W*a ≈ G*a

     # testing constant kernel
-    c = ConstantKernel(1)
+    c = Constant(1)
     g = GradientKernel(c)
     @test g(x, y) ≈ zeros(d, d)
 end
```

test/properties.jl (14 additions, 3 deletions)

```diff
@@ -1,13 +1,13 @@
 module TestProperties
 using Test
 using CovarianceFunctions
-using CovarianceFunctions: input_trait, DotProductInput, IsotropicInput, GenericInput
-using CovarianceFunctions: EQ, RQ, Exp, Dot, Poly, Line
+using CovarianceFunctions: input_trait, DotProductInput, IsotropicInput, StationaryLinearFunctionalInput, GenericInput
+using CovarianceFunctions: EQ, RQ, Exp, Dot, ExponentialDot, Poly, Line, Cos

 using LinearAlgebra

 @testset "properties" begin
-    dot_kernels = [Dot(), Dot()^3] # , Line(1.), Poly(5, 1.)] # TODO: take care of constants
+    dot_kernels = [Dot(), Dot()^3, ExponentialDot(), Line(1.), Poly(5, 1.)]
     for k in dot_kernels
         @test input_trait(k) isa DotProductInput
     end
@@ -19,6 +19,17 @@ using LinearAlgebra

     k = CovarianceFunctions.NeuralNetwork()
     @test input_trait(k) isa GenericInput
+
+    # testing that constant kernels don't confuse the input_trait inference
+    @test input_trait(1*EQ() + 1) isa IsotropicInput
+    @test input_trait(1*EQ() + 2 + RQ(1.)*1) isa IsotropicInput
+
+    @test input_trait(1*Dot() + 1) isa DotProductInput
+    @test input_trait(1*Dot() + 2 + Dot()^2*1) isa DotProductInput
+
+    w = randn()
+    @test input_trait(1*Cos(w) + 1) isa StationaryLinearFunctionalInput
+    @test input_trait(1*Cos(w) + 2 + Cos(w)^2*1) isa StationaryLinearFunctionalInput
 end

 end
```
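These tests lean on the operators from src/algebra.jl: multiplying or adding a scalar wraps it in `Constant`, so every tested expression contains constant constituents that `input_trait` must skip. Roughly, `1*EQ() + 1` desugars via `Base.:*` and `Base.:+` above into something like `Sum((Product((Constant(1), EQ())), Constant(1)))` (a sketch of the construction, not verified against the package), whose only non-constant constituent is `EQ()`, hence `IsotropicInput`.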
