Skip to content

Commit 2ad5529

Browse files
simplifying fix for GradientKernelElement
1 parent 7c95b1c commit 2ad5529

File tree

4 files changed

+54
-67
lines changed

4 files changed

+54
-67
lines changed

Manifest.toml

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,16 @@ version = "2.3.0"
3131
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
3232

3333
[[deps.ArrayInterface]]
34-
deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
35-
git-tree-sha1 = "81f0cb60dc994ca17f68d9fb7c942a5ae70d9ee4"
34+
deps = ["ArrayInterfaceCore", "Compat", "IfElse", "LinearAlgebra", "Static"]
35+
git-tree-sha1 = "a24db3a330d0ff64789abd52a26c732805619a53"
3636
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
37-
version = "5.0.8"
37+
version = "6.0.5"
38+
39+
[[deps.ArrayInterfaceCore]]
40+
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
41+
git-tree-sha1 = "d3a275e927d411e054c4192e5aca03998c233e94"
42+
uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2"
43+
version = "0.1.7"
3844

3945
[[deps.ArrayLayouts]]
4046
deps = ["FillArrays", "LinearAlgebra", "SparseArrays"]
@@ -84,21 +90,21 @@ version = "0.4.2"
8490

8591
[[deps.CUDA]]
8692
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
87-
git-tree-sha1 = "19fb33957a5f85efb3cc10e70cf4dd4e30174ac9"
93+
git-tree-sha1 = "925a16b909fdae16920c1319feadecffb6695b9d"
8894
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
89-
version = "3.10.0"
95+
version = "3.10.1"
9096

9197
[[deps.ChainRules]]
9298
deps = ["ChainRulesCore", "Compat", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics"]
93-
git-tree-sha1 = "de68815ccf15c7d3e3e3338f0bd3a8a0528f9b9f"
99+
git-tree-sha1 = "e9023f88b1655ffc6a4aaef2502878e8116151ef"
94100
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
95-
version = "1.33.0"
101+
version = "1.35.1"
96102

97103
[[deps.ChainRulesCore]]
98104
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
99-
git-tree-sha1 = "9950387274246d08af38f6eef8cb5480862a435f"
105+
git-tree-sha1 = "9489214b993cd42d17f44c36e359bf6a7c919abf"
100106
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
101-
version = "1.14.0"
107+
version = "1.15.0"
102108

103109
[[deps.ChangesOfVariables]]
104110
deps = ["ChainRulesCore", "LinearAlgebra", "Test"]
@@ -114,9 +120,9 @@ version = "0.3.0"
114120

115121
[[deps.Compat]]
116122
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
117-
git-tree-sha1 = "b153278a25dd42c65abbf4e62344f9d22e59191b"
123+
git-tree-sha1 = "87e84b2293559571802f97dd9c94cfd6be52c5e5"
118124
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
119-
version = "3.43.0"
125+
version = "3.44.0"
120126

121127
[[deps.CompilerSupportLibraries_jll]]
122128
deps = ["Artifacts", "Libdl"]
@@ -146,9 +152,9 @@ version = "1.10.0"
146152

147153
[[deps.DataStructures]]
148154
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
149-
git-tree-sha1 = "cc1a8e22627f33c789ab60b36a9132ac050bbf75"
155+
git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0"
150156
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
151-
version = "0.18.12"
157+
version = "0.18.13"
152158

153159
[[deps.DataValueInterfaces]]
154160
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
@@ -237,9 +243,9 @@ version = "0.13.2"
237243

238244
[[deps.Flux]]
239245
deps = ["Adapt", "ArrayInterface", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Test", "Zygote"]
240-
git-tree-sha1 = "f84e50845ab88702c721dc7c6129a85cbc1de332"
246+
git-tree-sha1 = "62350a872545e1369b1d8f11358a21681aa73929"
241247
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
242-
version = "0.13.1"
248+
version = "0.13.3"
243249

244250
[[deps.FoldsThreads]]
245251
deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"]
@@ -317,9 +323,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
317323

318324
[[deps.Intervals]]
319325
deps = ["Dates", "Printf", "RecipesBase", "Serialization", "TimeZones"]
320-
git-tree-sha1 = "b993074580045d1551d30990dc0fa5ba6feef92b"
326+
git-tree-sha1 = "1fd6fccdbdccee5997fb245289d98386c8996180"
321327
uuid = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
322-
version = "1.6.0"
328+
version = "1.7.0"
323329

324330
[[deps.InverseFunctions]]
325331
deps = ["Test"]
@@ -363,9 +369,9 @@ version = "1.1.1"
363369

364370
[[deps.LLVM]]
365371
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
366-
git-tree-sha1 = "c8d47589611803a0f3b4813d9e267cd4e3dbcefb"
372+
git-tree-sha1 = "10a20c556107dc5833d3bb7c5e45c4a6e191bd28"
367373
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
368-
version = "4.11.1"
374+
version = "4.13.0"
369375

370376
[[deps.LLVMExtra_jll]]
371377
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
@@ -428,9 +434,9 @@ uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
428434
version = "2022.0.0+0"
429435

430436
[[deps.MLStyle]]
431-
git-tree-sha1 = "e49789e5eb7b2d5577aaea395bfcac769df64bb8"
437+
git-tree-sha1 = "2041c1fd6833b3720d363c3ea8140bffaf86d9c4"
432438
uuid = "d8e11817-5142-5d16-987a-aa16d5891078"
433-
version = "0.4.11"
439+
version = "0.4.12"
434440

435441
[[deps.MLUtils]]
436442
deps = ["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase"]
@@ -506,9 +512,9 @@ version = "0.8.5"
506512

507513
[[deps.NNlibCUDA]]
508514
deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"]
509-
git-tree-sha1 = "0d18b4c80a92a00d3d96e8f9677511a7422a946e"
515+
git-tree-sha1 = "e161b835c6aa9e2339c1e72c3d4e39891eac7a4f"
510516
uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
511-
version = "0.2.2"
517+
version = "0.2.3"
512518

513519
[[deps.NaNMath]]
514520
git-tree-sha1 = "737a5957f387b17e74d4ad2f440eb330b39a62c5"
@@ -546,9 +552,9 @@ version = "0.5.5+0"
546552

547553
[[deps.Optimisers]]
548554
deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"]
549-
git-tree-sha1 = "2442c3ddbda547c80e8b6451a103719d6a3593dd"
555+
git-tree-sha1 = "26f58049054343c8103d67a5530284a35f1186cb"
550556
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
551-
version = "0.2.4"
557+
version = "0.2.5"
552558

553559
[[deps.OrderedCollections]]
554560
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
@@ -682,9 +688,9 @@ version = "0.1.14"
682688

683689
[[deps.Static]]
684690
deps = ["IfElse"]
685-
git-tree-sha1 = "3a2a99b067090deb096edecec1dc291c5b4b31cb"
691+
git-tree-sha1 = "5d2c08cef80c7a3a8ba9ca023031a85c263012c5"
686692
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
687-
version = "0.6.5"
693+
version = "0.6.6"
688694

689695
[[deps.StaticArrays]]
690696
deps = ["LinearAlgebra", "Random", "Statistics"]
@@ -708,6 +714,10 @@ git-tree-sha1 = "8977b17906b0a1cc74ab2e3a05faa16cf08a8291"
708714
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
709715
version = "0.33.16"
710716

717+
[[deps.SuiteSparse]]
718+
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
719+
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
720+
711721
[[deps.SymEngine]]
712722
deps = ["Compat", "Libdl", "LinearAlgebra", "RecipesBase", "SpecialFunctions", "SymEngine_jll"]
713723
git-tree-sha1 = "6cf88a0b98c758a36e6e978a41e8a12f6f5cdacc"

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CovarianceFunctions"
22
uuid = "b3329135-7958-41d4-bb02-e084c5a526bf"
33
authors = ["sebastianament"]
4-
version = "0.3.5"
4+
version = "0.3.6"
55

66
[deps]
77
BesselK = "432ab697-7a72-484f-bc4a-bc531f5c819b"

src/gradient.jl

Lines changed: 15 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,27 @@ Base.eltype(K::GradientKernelElement{T}) where T = T
6060
# gradient kernel element only used for sparsely representable elements
6161
Base.Matrix(K::GradientKernelElement) = K * I(size(K, 1))
6262

63+
function GradientKernelElement{T}(k, x, y, it::InputTrait) where T
64+
GradientKernelElement{T, typeof(k), typeof(x), typeof(y), typeof(it)}(k, x, y, it)
65+
end
66+
67+
function gradient_kernel(k, x, y, it::InputTrait)
68+
T = gramian_eltype(k, x, y)
69+
GradientKernelElement{T}(k, x, y, it)
70+
end
71+
72+
function gradient_kernel!(K::GradientKernelElement, k, x, y, it::InputTrait)
73+
GradientKernelElement{eltype(K)}(k, x, y, it)
74+
end
75+
6376
function Base.:*(G::GradientKernelElement, a)
6477
T = promote_type(eltype(G), eltype(a))
6578
b = zeros(T, size(a))
6679
mul!(b, G, a)
6780
end
6881

6982
const GenericGradientKernelElement{T, K, X, Y} = GradientKernelElement{T, K, X, Y, <:GenericInput}
70-
7183
const IsotropicGradientKernelElement{T, K, X, Y} = GradientKernelElement{T, K, X, Y, IsotropicInput}
72-
function IsotropicGradientKernelElement{T}(k, x, y) where T
73-
IsotropicGradientKernelElement{T, typeof(k), typeof(x), typeof(y)}(k, x, y, IsotropicInput())
74-
end
7584

7685
# isotropic kernel
7786
function LinearAlgebra.mul!(b, G::IsotropicGradientKernelElement, a, α::Number = 1, β::Number = 0) #, ::IsotropicInput = G.input_trait)
@@ -95,19 +104,8 @@ function WoodburyFactorizations.Woodbury(K::IsotropicGradientKernelElement)
95104
return K = Woodbury(D, r, C, r')
96105
end
97106

98-
function gradient_kernel!(K::IsotropicGradientKernelElement, k, x, y, ::IsotropicInput)
99-
typeof(K)(k, x, y, IsotropicInput())
100-
end
101-
102-
function gradient_kernel(k, x, y, ::IsotropicInput)
103-
T = gramian_eltype(k, x, y)
104-
IsotropicGradientKernelElement{T}(k, x, y)
105-
end
106-
107107
const DotProductGradientKernelElement{T, K, X, Y} = GradientKernelElement{T, K, X, Y, DotProductInput}
108-
function DotProductGradientKernelElement{T}(k, x, y) where T
109-
DotProductGradientKernelElement{T, typeof(k), typeof(x), typeof(y)}(k, x, y, DotProductInput())
110-
end
108+
111109
function LinearAlgebra.mul!(b, K::DotProductGradientKernelElement, a, α::Number = 1, β::Number = 0)
112110
k, x, y = K.k, K.x, K.y
113111
= dot(x, y)
@@ -126,19 +124,7 @@ function WoodburyFactorizations.Woodbury(K::DotProductGradientKernelElement)
126124
return K = Woodbury(D, copy(y), C, copy(x)')
127125
end
128126

129-
function gradient_kernel!(K::DotProductGradientKernelElement, k, x, y, ::DotProductInput)
130-
typeof(K)(k, x, y, DotProductInput())
131-
end
132-
133-
function gradient_kernel(k, x, y, ::DotProductInput)
134-
T = gramian_eltype(k, x, y)
135-
DotProductGradientKernelElement{T}(k, x, y)
136-
end
137-
138127
const LinearFunctionalGradientKernelElement{T, K, X, Y} = GradientKernelElement{T, K, X, Y, StationaryLinearFunctionalInput}
139-
function LinearFunctionalGradientKernelElement{T}(k, x, y) where T
140-
LinearFunctionalGradientKernelElement{T, typeof(k), typeof(x), typeof(y)}(k, x, y, StationaryLinearFunctionalInput())
141-
end
142128

143129
function LinearAlgebra.mul!(b, K::LinearFunctionalGradientKernelElement, a, α::Number = 1, β::Number = 0)
144130
k, x, y = K.k, K.x, K.y
@@ -162,16 +148,6 @@ function LazyMatrixProduct(K::LinearFunctionalGradientKernelElement)
162148
return LazyMatrixProduct(c, c2')
163149
end
164150

165-
# is this necessary?
166-
function gradient_kernel!(K::LinearFunctionalGradientKernelElement, k, x, y, ::StationaryLinearFunctionalInput)
167-
typeof(K)(k, x, y, StationaryLinearFunctionalInput())
168-
end
169-
170-
function gradient_kernel(k, x, y, ::StationaryLinearFunctionalInput)
171-
T = gramian_eltype(k, x, y)
172-
LinearFunctionalGradientKernelElement{T}(k, x, y)
173-
end
174-
175151
function evaluate_block!(Gij, k::GradientKernel, x, y, IT = input_trait(k))
176152
gradient_kernel!(Gij, k.k, x, y, IT)
177153
end
@@ -420,7 +396,7 @@ end
420396
################################################################################
421397
# [f, ∂f] ∼ GP([μ, ∂μ], dK) # value + gradient kernel
422398
# IDEA: For efficiency, maybe create ValueGradientKernelElement like in hessian.jl
423-
# currently, this is an order of magnitude slower than GradientKernel
399+
# might not be necessary anymore, benchmark against GradientKernel
424400
struct ValueGradientKernel{T, K, IT<:InputTrait} <: AbstractDerivativeKernel{T, K}
425401
k::K
426402
input_trait::IT

src/gradient_algebra.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
############################ Gradient Algebra ##################################
2+
# IDEA: could have specialization for gradient kernels of Power kernels of composite kernels
23
################################### Sum ########################################
34
# allocates space for gradient kernel evaluation but does not evaluate
45
# the separation from evaluation is useful for ValueGradientKernel

0 commit comments

Comments
 (0)