
Commit f7fdb1e

working version of fast taylor decomposition for gradient kernels
1 parent c420c1c commit f7fdb1e

File tree: 3 files changed, +63 -29 lines


src/CovarianceFunctions.jl

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ include("taylor.jl")
 include("gradient.jl")
 include("gradient_algebra.jl")
 include("hessian.jl")
-# include("taylor_gradient.jl") # fast MVM algorithm for isotropic GradientKernel Gramians
+include("taylor_gradient.jl") # fast MVM algorithm for isotropic GradientKernel Gramians
 include("separable.jl")

 end # CovarianceFunctions
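For orientation, below is a minimal sketch of the fast MVM path that the now-included taylor_gradient.jl enables. It is pieced together from the commented-out gradient-kernel test later in this commit, not from documented API, so the constructor and taylor! signature should be read as assumptions; the sizes and the opening angle θ are arbitrary example values.

using LinearAlgebra
using CovarianceFunctions
using CovarianceFunctions: BarnesHutFactorization, GradientKernel, taylor!

n, d = 1024, 2
x = randn(d, n)                   # inputs as columns
k = CovarianceFunctions.EQ()      # isotropic base kernel
g = GradientKernel(k)             # d×d matrix-valued gradient kernel

F = BarnesHutFactorization(g, x)  # hierarchical factorization of the gradient-kernel Gramian
a = [randn(d) for _ in 1:n]       # block input vector
b = [zeros(d) for _ in 1:n]       # block output vector
taylor!(b, F, a, 1, 0, 1/10)      # fast approximate MVM: b ≈ Gramian(g, x) * a with θ = 1/10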

src/barneshut.jl

Lines changed: 16 additions & 16 deletions
@@ -30,7 +30,7 @@ function BarnesHutFactorization(k, x, y = x, D = nothing; θ::Real = 1/4, leafsi
     # w = zeros(length(m))
     # i = zeros(Bool, m)
     # WT, BT = typeof(w), typeof(i)
-    T = gramian_eltype(k, xs, ys)
+    T = gramian_eltype(k, xs[1], ys[1])
     BarnesHutFactorization{T, KT, XT, YT, TT, DT, RT}(k, xs, ys, Tree, D, θ) #, w, i)
 end
 function BarnesHutFactorization(G::Gramian, θ::Real = 1/2; leafsize::Int = BARNES_HUT_DEFAULT_LEAFSIZE)
@@ -49,7 +49,7 @@ function LinearAlgebra.mul!(y::AbstractVector, F::BarnesHutFactorization, x::Abs
         taylor!(y, F, x, α, β)
     end
 end
-function Base.:*(F::BarnesHutFactorization, x::AbstractVector)
+function Base.:*(F::BarnesHutFactorization{<:Number}, x::AbstractVector{<:Number})
     T = promote_type(eltype(F), eltype(x))
     y = zeros(T, size(F, 1))
     mul!(y, F, x)
@@ -148,45 +148,45 @@ end

 ############################# centers of mass ##################################
 # this is a weighted sum, could be generalized to incorporate node_sums
-function compute_centers_of_mass(x::AbstractVector, w::AbstractVector, T::BallTree)
+function compute_centers_of_mass(w::AbstractVector, x::AbstractVector, T::BallTree)
     D = eltype(x) <: StaticVector ? length(eltype(x)) : length(x[1]) # if x is static vector
     com = [zero(MVector{D, Float64}) for _ in 1:length(T.hyper_spheres)]
-    compute_centers_of_mass!(com, x, w, T)
+    compute_centers_of_mass!(com, w, x, T)
 end

 function compute_centers_of_mass(F::BarnesHutFactorization, w::AbstractVector)
-    compute_centers_of_mass(F.y, w, F.Tree)
+    compute_centers_of_mass(w, F.y, F.Tree)
 end

-function compute_centers_of_mass!(com::AbstractVector, x::AbstractVector, w::AbstractVector, T::BallTree)
+function compute_centers_of_mass!(com::AbstractVector, w::AbstractVector, x::AbstractVector, T::BallTree)
     abs_w = abs.(w)
-    weighted_node_sums!(com, x, abs_w, T)
+    weighted_node_sums!(com, abs_w, x, T)
     sum_w = node_sums(abs_w, T)
     ε = eps(eltype(w)) # ensuring division by zero is not a problem
     @. com ./= sum_w + ε
 end

-node_sums(x::AbstractVector, T::BallTree) = weighted_node_sums(x, Ones(length(x)), T)
+node_sums(x::AbstractVector, T::BallTree) = weighted_node_sums(Ones(length(x)), x, T)
 function node_sums!(sums, x::AbstractVector, T::BallTree)
-    weighted_node_sums!(sums, x, Ones(length(x)), T)
+    weighted_node_sums!(sums, Ones(length(x)), x, T)
 end

-function weighted_node_sums(x::AbstractVector, w::AbstractVector, T::BallTree, index::Int = 1)
+function weighted_node_sums(w::AbstractVector, x::AbstractVector, T::BallTree, index::Int = 1)
     length(x) == 0 && return zero(eltype(x))
-    sums = zeros(typeof(w[1]'x[1]), length(T.hyper_spheres))
-    weighted_node_sums!(sums, x, w, T)
+    sums = fill(zero(w[1]'x[1]), length(T.hyper_spheres))
+    weighted_node_sums!(sums, w, x, T)
 end

 # NOTE: x should either be vector of numbers or vector of static arrays
-function weighted_node_sums!(sums::AbstractVector, x::AbstractVector,
-                             w::AbstractVector{<:Number}, T::BallTree, index::Int = 1)
+function weighted_node_sums!(sums::AbstractVector, w::AbstractVector,
+                             x::AbstractVector, T::BallTree, index::Int = 1)
     if isleaf(T.tree_data.n_internal_nodes, index)
         i = get_leaf_range(T.tree_data, index)
         wi, xi = @views w[T.indices[i]], x[T.indices[i]]
         sums[index] = wi'xi
     else
-        task = @spawn weighted_node_sums!(sums, x, w, T, getleft(index))
-        weighted_node_sums!(sums, x, w, T, getright(index))
+        task = @spawn weighted_node_sums!(sums, w, x, T, getleft(index))
+        weighted_node_sums!(sums, w, x, T, getright(index))
         wait(task)
         sums[index] = sums[getleft(index)] + sums[getright(index)]
     end
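The reordered signatures above consistently put the weight vector before the points in the tree-summation helpers. A minimal sketch of the new calling convention follows; these helpers are internals of CovarianceFunctions (not exported), and the input shapes and the use of vector_of_static_vectors on a d×n matrix are assumptions for illustration only.

using NearestNeighbors: BallTree
using CovarianceFunctions
using CovarianceFunctions: vector_of_static_vectors, node_sums

d, n = 2, 64
x = vector_of_static_vectors(randn(d, n))  # points as static vectors
w = rand(n)                                # per-point weights
tree = BallTree(x)

s = node_sums(w, tree)                                         # sum of the weights in each tree node
ws = CovarianceFunctions.weighted_node_sums(w, x, tree)        # weights first, then points
com = CovarianceFunctions.compute_centers_of_mass(w, x, tree)  # same order for the |w|-weighted centers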

test/barneshut.jl

Lines changed: 46 additions & 12 deletions
@@ -1,12 +1,14 @@
 module TestBarnesHut
 using LinearAlgebra
+using WoodburyFactorizations
 using CovarianceFunctions
 using CovarianceFunctions: BarnesHutFactorization, barneshut!, vector_of_static_vectors,
-                           node_sums, euclidean, GradientKernel, taylor!
+                           node_sums, euclidean, GradientKernel, taylor!, IsotropicGradientKernelElement
 using NearestNeighbors
 using NearestNeighbors: isleaf, getleft, getright, get_leaf_range
 using Test

+# like barnes hut but puts 0 as far field contribution
 function barneshut_no_far_field!(b::AbstractVector, F::BarnesHutFactorization, w::AbstractVector,
                                  α::Number = 1, β::Number = 0, θ::Real = F.θ)
     D = length(eltype(F.x))
@@ -95,10 +97,9 @@ verbose = false
     nexp = 16
     err = zeros(nexp)
     err_no_split = zeros(nexp)
-    err_hyp = zeros(nexp)
-    err_hyp_no_split = zeros(nexp)
     err_nff = zeros(nexp)
     err_taylor = zeros(nexp)
+    err_taylor_hyp = zeros(nexp)
    theta_array = range(1e-1, 1, length = nexp)
    for (i, θ) in enumerate(theta_array)
        barneshut!(b_bh, F, w, 1, 0, θ, split = true)
@@ -110,8 +111,11 @@ verbose = false
        barneshut_no_far_field!(b_bh, F, w, 1, 0, θ) # compare against pseudo barnes hut where far field = 0
        err_nff[i] = norm(b - b_bh)

-       taylor!(b_bh, F, w, 1, 0, θ) # compare against pseudo barnes hut where far field = 0
+       taylor!(b_bh, F, w, 1, 0, θ, use_com = true)
        err_taylor[i] = norm(b - b_bh)
+
+       taylor!(b_bh, F, w, 1, 0, θ, use_com = false)
+       err_taylor_hyp[i] = norm(b - b_bh)
    end

    rel_err = err / norm(b)
@@ -122,28 +126,58 @@ verbose = false
 # using Plots
 # plotly()
 #
-# rel_err_nff = err_nff / norm(b)
-# rel_err_no_split = err_no_split / norm(b)
-# rel_err_taylor = err_taylor / norm(b)
-#
+# norm_b = norm(b)
+# rel_err_nff = err_nff / norm_b
+# rel_err_no_split = err_no_split / norm_b
+# rel_err_taylor = err_taylor / norm_b
+# rel_err_taylor_hyp = err_taylor_hyp / norm_b
+# #
 # plot(theta_array, rel_err, yscale = :log10, label = "barneshut", ylabel = "relative error", xlabel = "θ")
 # plot!(theta_array, rel_err_no_split, yscale = :log10, label = "no split")
 # plot!(theta_array, rel_err_nff, yscale = :log10, label = "sparse")
 # plot!(theta_array, rel_err_taylor, yscale = :log10, label = "taylor")
+# plot!(theta_array, rel_err_taylor_hyp, yscale = :log10, label = "taylor hyper-sphere centers")
 # gui()

 end # testset weight vectors

-
 # @testset "gradient kernels" begin
-# k = CovarianceFunctions.Cauchy()
+#
+# n = 1024
+# d = 2
+# x = randn(d, n)
+# # k = CovarianceFunctions.Cauchy()
+# k = CovarianceFunctions.EQ()
 # g = CovarianceFunctions.GradientKernel(k)
 #
 # F = BarnesHutFactorization(g, x)
 # @test F isa BarnesHutFactorization
-# @test eltype(F) <: Diagonal
+# @test eltype(F) <: IsotropicGradientKernelElement
 # @test size(F) == (n, n)
-# @test size(F[1, 1]) == (d_out, d_out)
+# @test size(F[1, 1]) == (d, d)
+#
+# a = [randn(d) for _ in 1:n]
+# b = [zeros(d) for _ in 1:n]
+# # F*a
+# G = gramian(g, x)
+# b_truth = deepcopy(b)
+# # @time b_truth = G.A * a
+# # @time b_truth = G.A * a
+# mul!(b_truth, G.A, a)
+# norm_b = sqrt(sum(sum.(abs2, b_truth)))
+#
+# α, β = 1, 0
+# θ = 0
+# taylor!(b, F, a, α, β, θ)
+# err = sqrt(sum(sum.(abs2, b - b_truth)))
+# rel_err = err / norm_b
+# @test rel_err < 1e-10
+#
+# θ = 1/10
+# taylor!(b, F, a, α, β, θ)
+# err = sqrt(sum(sum.(abs2, b - b_truth)))
+# rel_err = err / norm_b
+# @test rel_err < 1e-3
 # end # testset matrix valued bh

 end # testset
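The new err_taylor_hyp measurements exercise the use_com keyword that this commit adds to taylor!: judging from the commented plot labels, use_com = true expands the far field about the nodes' centers of mass, while use_com = false expands about the ball-tree hyper-sphere centers. A rough sketch of that comparison for a scalar isotropic kernel, mirroring the test loop; the data sizes, θ, and the dense reference via gramian are illustrative assumptions, not taken from the source.

using LinearAlgebra
using CovarianceFunctions
using CovarianceFunctions: BarnesHutFactorization, taylor!

n, d = 1024, 2
x = randn(d, n)
k = CovarianceFunctions.EQ()
F = BarnesHutFactorization(k, x)
w = randn(n)
b_exact = Matrix(CovarianceFunctions.gramian(k, x)) * w  # dense reference MVM

b = zeros(n)
θ = 1/4
taylor!(b, F, w, 1, 0, θ, use_com = true)   # expand about centers of mass
err_com = norm(b - b_exact)
taylor!(b, F, w, 1, 0, θ, use_com = false)  # expand about hyper-sphere centers
err_hyp = norm(b - b_exact)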
