Skip to content

Commit 6045925

Browse files
working version of fast taylor decomposition for gradient kernels
1 parent c420c1c commit 6045925

File tree

4 files changed

+200
-36
lines changed

4 files changed

+200
-36
lines changed

src/CovarianceFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ include("taylor.jl")
6060
include("gradient.jl")
6161
include("gradient_algebra.jl")
6262
include("hessian.jl")
63-
# include("taylor_gradient.jl") # fast MVM algorithm for isotropic GradientKernel Gramians
63+
include("taylor_gradient.jl") # fast MVM algorithm for isotropic GradientKernel Gramians
6464
include("separable.jl")
6565

6666
end # CovarianceFunctions

src/barneshut.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function BarnesHutFactorization(k, x, y = x, D = nothing; θ::Real = 1/4, leafsi
3030
# w = zeros(length(m))
3131
# i = zeros(Bool, m)
3232
# WT, BT = typeof(w), typeof(i)
33-
T = gramian_eltype(k, xs, ys)
33+
T = gramian_eltype(k, xs[1], ys[1])
3434
BarnesHutFactorization{T, KT, XT, YT, TT, DT, RT}(k, xs, ys, Tree, D, θ) #, w, i)
3535
end
3636
function BarnesHutFactorization(G::Gramian, θ::Real = 1/2; leafsize::Int = BARNES_HUT_DEFAULT_LEAFSIZE)
@@ -49,7 +49,7 @@ function LinearAlgebra.mul!(y::AbstractVector, F::BarnesHutFactorization, x::Abs
4949
taylor!(y, F, x, α, β)
5050
end
5151
end
52-
function Base.:*(F::BarnesHutFactorization, x::AbstractVector)
52+
function Base.:*(F::BarnesHutFactorization{<:Number}, x::AbstractVector{<:Number})
5353
T = promote_type(eltype(F), eltype(x))
5454
y = zeros(T, size(F, 1))
5555
mul!(y, F, x)
@@ -148,45 +148,45 @@ end
148148

149149
############################# centers of mass ##################################
150150
# this is a weighted sum, could be generalized to incorporate node_sums
151-
function compute_centers_of_mass(x::AbstractVector, w::AbstractVector, T::BallTree)
151+
function compute_centers_of_mass(w::AbstractVector, x::AbstractVector, T::BallTree)
152152
D = eltype(x) <: StaticVector ? length(eltype(x)) : length(x[1]) # if x is static vector
153153
com = [zero(MVector{D, Float64}) for _ in 1:length(T.hyper_spheres)]
154-
compute_centers_of_mass!(com, x, w, T)
154+
compute_centers_of_mass!(com, w, x, T)
155155
end
156156

157157
function compute_centers_of_mass(F::BarnesHutFactorization, w::AbstractVector)
158-
compute_centers_of_mass(F.y, w, F.Tree)
158+
compute_centers_of_mass(w, F.y, F.Tree)
159159
end
160160

161-
function compute_centers_of_mass!(com::AbstractVector, x::AbstractVector, w::AbstractVector, T::BallTree)
161+
function compute_centers_of_mass!(com::AbstractVector, w::AbstractVector, x::AbstractVector, T::BallTree)
162162
abs_w = abs.(w)
163-
weighted_node_sums!(com, x, abs_w, T)
163+
weighted_node_sums!(com, abs_w, x, T)
164164
sum_w = node_sums(abs_w, T)
165165
ε = eps(eltype(w)) # ensuring division by zero it not a problem
166166
@. com ./= sum_w + ε
167167
end
168168

169-
node_sums(x::AbstractVector, T::BallTree) = weighted_node_sums(x, Ones(length(x)), T)
169+
node_sums(x::AbstractVector, T::BallTree) = weighted_node_sums(Ones(length(x)), x, T)
170170
function node_sums!(sums, x::AbstractVector, T::BallTree)
171-
weighted_node_sums!(sums, x, Ones(length(x)), T)
171+
weighted_node_sums!(sums, Ones(length(x)), x, T)
172172
end
173173

174-
function weighted_node_sums(x::AbstractVector, w::AbstractVector, T::BallTree, index::Int = 1)
174+
function weighted_node_sums(w::AbstractVector, x::AbstractVector, T::BallTree, index::Int = 1)
175175
length(x) == 0 && return zero(eltype(x))
176-
sums = zeros(typeof(w[1]'x[1]), length(T.hyper_spheres))
177-
weighted_node_sums!(sums, x, w, T)
176+
sums = fill(zero(w[1]'x[1]), length(T.hyper_spheres))
177+
weighted_node_sums!(sums, w, x, T)
178178
end
179179

180180
# NOTE: x should either be vector of numbers or vector of static arrays
181-
function weighted_node_sums!(sums::AbstractVector, x::AbstractVector,
182-
w::AbstractVector{<:Number}, T::BallTree, index::Int = 1)
181+
function weighted_node_sums!(sums::AbstractVector, w::AbstractVector,
182+
x::AbstractVector, T::BallTree, index::Int = 1)
183183
if isleaf(T.tree_data.n_internal_nodes, index)
184184
i = get_leaf_range(T.tree_data, index)
185185
wi, xi = @views w[T.indices[i]], x[T.indices[i]]
186186
sums[index] = wi'xi
187187
else
188-
task = @spawn weighted_node_sums!(sums, x, w, T, getleft(index))
189-
weighted_node_sums!(sums, x, w, T, getright(index))
188+
task = @spawn weighted_node_sums!(sums, w, x, T, getleft(index))
189+
weighted_node_sums!(sums, w, x, T, getright(index))
190190
wait(task)
191191
sums[index] = sums[getleft(index)] + sums[getright(index)]
192192
end

src/taylor_gradient.jl

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
############################### matrix valued version ##########################
2+
# function BarnesHutFactorization(k::GradientKernel, x, y = x, D = nothing;
3+
# θ::Real = 1/4, leafsize::Int = BARNES_HUT_DEFAULT_LEAFSIZE)
4+
# xs = vector_of_static_vectors(x)
5+
# ys = x === y ? xs : vector_of_static_vectors(y)
6+
# Tree = BallTree(ys, leafsize = leafsize)
7+
# m = length(y)
8+
# XT, YT, KT, TT, DT, RT = typeof.((xs, ys, k, Tree, D, θ))
9+
# # w = zeros(length(m))
10+
# # i = zeros(Bool, m)
11+
# # WT, BT = typeof(w), typeof(i)
12+
# T = gramian_eltype(k, xs, ys)
13+
# F = BarnesHutFactorization{T, KT, XT, YT, TT, DT, RT}(k, xs, ys, Tree, D, θ) #, w, i)
14+
# end
15+
16+
function LinearAlgebra.mul!(y::AbstractVector{<:Real}, F::BarnesHutFactorization{<:Any, <:GradientKernel},
17+
x::AbstractVector{<:Real}, α::Real = 1, β::Real = 0)
18+
d = length(F.x[1])
19+
X, Y = reshape(x, d, :), reshape(y, d, :) # converting vector of reals to vector of vectors
20+
xx, yy = [c for c in eachcol(X)], [c for c in eachcol(Y)]
21+
mul!(yy, F, xx, α, β)
22+
return y
23+
end
24+
25+
function LinearAlgebra.mul!(y::AbstractVector{<:AbstractVector}, F::BarnesHutFactorization{<:Any, <:GradientKernel},
26+
x::AbstractVector{<:AbstractVector}, α::Real = 1, β::Real = 0)
27+
taylor!(y, F, x, α, β)
28+
end
29+
30+
# function Base.:*(F::BarnesHutFactorization{<:Number}, x::AbstractVector{<:Number})
31+
# T = promote_type(eltype(F), eltype(x))
32+
# y = zeros(T, size(F, 1))
33+
# mul!(y, F, x)
34+
# end
35+
36+
function taylor!(b::AbstractVector{<:AbstractVector}, F::BarnesHutFactorization{<:Any, <:GradientKernel},
37+
w::AbstractVector{<:AbstractVector}, α::Number = 1, β::Number = 0, θ::Real = F.θ)
38+
size(F, 2) == length(w) || throw(DimensionMismatch("length of w does not match second dimension of F: $(length(w))$(size(F, 2))"))
39+
# eltype(b) == promote_type(eltype(F), eltype(w)) ||
40+
# throw(TypeError("eltype of target vector b not equal to eltype of matrix-vector product: $(eltype(b)) and $(promote_type(eltype(F), eltype(w)))"))
41+
f_orders(r²) = derivatives(F.k.k, r², 3)
42+
sums_w = node_sums(w, F.Tree) # IDEA: could pre-allocate, worth it? is several orders of magnitude less expensive than multiply
43+
sums_w_r = weighted_node_sums(adjoint.(w), adjoint.(F.y), F.Tree) # need sum of outer products of F.y and w
44+
centers = get_hyper_centers(F)
45+
@. sums_w_r -= sums_w * adjoint(centers) # need to center the moments
46+
Gijs = [F[1, 1] for _ in 1:Base.Threads.nthreads()]
47+
for i in eachindex(F.x) # exactly 4 * length(y) allocations?
48+
if β == 0
49+
@. b[i] = 0 # this avoids trouble if b is initialized with NaN's, e.g. thorugh "similar"
50+
else
51+
@. b[i] *= β
52+
end
53+
Gij = Gijs[Base.Threads.threadid()]
54+
taylor_recursion!(b[i], Gij, 1, F.k, f_orders, F.x[i], F.y,
55+
w, sums_w, sums_w_r, θ, F.Tree, centers, α) # x[i] creates an allocation
56+
end
57+
if !isnothing(F.D) # if there is a diagonal correction, need to add it
58+
mul!(b, F.D, w, α, 1)
59+
end
60+
return b
61+
end
62+
63+
# barnes hut recursion for matrix-valued kernels, could merge with scalar version
64+
# bi is target vector corresponding to input point xi
65+
# Gij is temporary storage for evaluation of k(xi, y[j]), important if it is matrix valued
66+
# to avoid allocations
67+
function taylor_recursion!(bi::AbstractVector, Gij,
68+
index, k::GradientKernel, f_orders,
69+
xi, y::AbstractVector,
70+
w::AbstractVector{<:AbstractVector},
71+
sums_w::AbstractVector{<:AbstractVector},
72+
sums_w_r::AbstractVector{<:AbstractMatrix},
73+
θ::Real, T::BallTree, centers, α::Number)
74+
h = T.hyper_spheres[index]
75+
if isleaf(T.tree_data.n_internal_nodes, index) # do direct computation
76+
for i in get_leaf_range(T.tree_data, index)
77+
j = T.indices[i]
78+
# @time Gij = evaluate_block!(Gij, k, xi, y[j]) # k(xi, y[j])
79+
Gij = IsotropicGradientKernelElement{eltype(bi)}(k.k, xi, y[j])
80+
wj = w[j]
81+
mul!(bi, Gij, wj, α, 1)
82+
end
83+
return bi
84+
85+
elseif h.r < θ * euclidean(xi, centers[index]) # compress
86+
S_index = sums_w_r[index]
87+
ri = difference(xi, centers[index])
88+
sum_abs2_ri = sum(abs2, ri)
89+
# NOTE: this line is the only one that still allocates ~688 bytes
90+
f0, f1, f2, f3 = f_orders(sum_abs2_ri) # contains first and second order derivative evaluation
91+
92+
# Gij = evaluate_block!(Gij, k, xi, centers[index]) # k(xi, centers_of_mass[index])
93+
Gij = IsotropicGradientKernelElement{eltype(bi)}(k.k, xi, centers[index])
94+
95+
# zeroth order
96+
mul!(bi, Gij, sums_w[index], α, 1) # bi .+= α * k(xi, centers_of_mass[index]) * sums_w[index]
97+
# first order
98+
# this block has zero allocations
99+
mul!(bi, 2*f3*dot(ri, S_index, ri) + f2 * tr(S_index), ri, 4α, 1)
100+
mul!(bi, S_index, ri, 4α*f2, 1)
101+
mul!(bi, S_index', ri, 4α*f2, 1)
102+
return bi
103+
else # recurse NOTE: parallelizing here is not as efficient as parallelizing over target points
104+
taylor_recursion!(bi, Gij, getleft(index), k, f_orders, xi, y, w, sums_w, sums_w_r, θ, T, centers, α)
105+
taylor_recursion!(bi, Gij, getright(index), k, f_orders, xi, y, w, sums_w, sums_w_r, θ, T, centers, α)
106+
end
107+
end
108+
109+
# function node_mapreduce(x::AbstractVector, w::AbstractVector, T::BallTree, index::Int = 1)
110+
# length(x) == 0 && return zero(eltype(x))
111+
# sums = fill(w[1]*x[1]', length(T.hyper_spheres))
112+
# node_outer_products!(sums, x, w, T)
113+
# end
114+
#
115+
# # NOTE: x should either be vector of numbers or vector of static arrays
116+
# function node_mapreduce!(sums::AbstractVector{<:AbstractMatrix}, x::AbstractVector,
117+
# w::AbstractVector{<:Number}, T::BallTree, index::Int = 1)
118+
# if isleaf(T.tree_data.n_internal_nodes, index)
119+
# i = get_leaf_range(T.tree_data, index)
120+
# wi, xi = @views w[T.indices[i]], x[T.indices[i]]
121+
# sums[index] = wi * xi'
122+
# # adjoint.(w)' * adjoint.(x)
123+
# else
124+
# task = @spawn weighted_node_sums!(sums, x, w, T, getleft(index))
125+
# weighted_node_sums!(sums, x, w, T, getright(index))
126+
# wait(task)
127+
# sums[index] = sums[getleft(index)] + sums[getright(index)]
128+
# end
129+
# return sums
130+
# end

test/barneshut.jl

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
module TestBarnesHut
22
using LinearAlgebra
3+
using WoodburyFactorizations
34
using CovarianceFunctions
45
using CovarianceFunctions: BarnesHutFactorization, barneshut!, vector_of_static_vectors,
5-
node_sums, euclidean, GradientKernel, taylor!
6+
node_sums, euclidean, GradientKernel, taylor!, IsotropicGradientKernelElement
67
using NearestNeighbors
78
using NearestNeighbors: isleaf, getleft, getright, get_leaf_range
89
using Test
910

11+
# like barnes hut but puts 0 as far field contribution
1012
function barneshut_no_far_field!(b::AbstractVector, F::BarnesHutFactorization, w::AbstractVector,
1113
α::Number = 1, β::Number = 0, θ::Real = F.θ)
1214
D = length(eltype(F.x))
@@ -95,10 +97,9 @@ verbose = false
9597
nexp = 16
9698
err = zeros(nexp)
9799
err_no_split = zeros(nexp)
98-
err_hyp = zeros(nexp)
99-
err_hyp_no_split = zeros(nexp)
100100
err_nff = zeros(nexp)
101101
err_taylor = zeros(nexp)
102+
err_taylor_hyp = zeros(nexp)
102103
theta_array = range(1e-1, 1, length = nexp)
103104
for (i, θ) in enumerate(theta_array)
104105
barneshut!(b_bh, F, w, 1, 0, θ, split = true)
@@ -110,8 +111,11 @@ verbose = false
110111
barneshut_no_far_field!(b_bh, F, w, 1, 0, θ) # compare against pseudo barnes hut where far field = 0
111112
err_nff[i] = norm(b - b_bh)
112113

113-
taylor!(b_bh, F, w, 1, 0, θ) # compare against pseudo barnes hut where far field = 0
114+
taylor!(b_bh, F, w, 1, 0, θ, use_com = true)
114115
err_taylor[i] = norm(b - b_bh)
116+
117+
taylor!(b_bh, F, w, 1, 0, θ, use_com = false)
118+
err_taylor_hyp[i] = norm(b - b_bh)
115119
end
116120

117121
rel_err = err / norm(b)
@@ -122,29 +126,59 @@ verbose = false
122126
# using Plots
123127
# plotly()
124128
#
125-
# rel_err_nff = err_nff / norm(b)
126-
# rel_err_no_split = err_no_split / norm(b)
127-
# rel_err_taylor = err_taylor / norm(b)
128-
#
129+
# norm_b = norm(b)
130+
# rel_err_nff = err_nff / norm_b
131+
# rel_err_no_split = err_no_split / norm_b
132+
# rel_err_taylor = err_taylor / norm_b
133+
# rel_err_taylor_hyp = err_taylor_hyp / norm_b
134+
# #
129135
# plot(theta_array, rel_err, yscale = :log10, label = "barneshut", ylabel = "relative error", xlabel = "θ")
130136
# plot!(theta_array, rel_err_no_split, yscale = :log10, label = "no split")
131137
# plot!(theta_array, rel_err_nff, yscale = :log10, label = "sparse")
132138
# plot!(theta_array, rel_err_taylor, yscale = :log10, label = "taylor")
139+
# plot!(theta_array, rel_err_taylor_hyp, yscale = :log10, label = "taylor hyper-sphere centers")
133140
# gui()
134141

135142
end # testset weight vectors
136143

137-
138-
# @testset "gradient kernels" begin
139-
# k = CovarianceFunctions.Cauchy()
140-
# g = CovarianceFunctions.GradientKernel(k)
141-
#
142-
# F = BarnesHutFactorization(g, x)
143-
# @test F isa BarnesHutFactorization
144-
# @test eltype(F) <: Diagonal
145-
# @test size(F) == (n, n)
146-
# @test size(F[1, 1]) == (d_out, d_out)
147-
# end # testset matrix valued bh
144+
@testset "gradient kernels" begin
145+
146+
n = 1024
147+
d = 2
148+
x = randn(d, n)
149+
# k = CovarianceFunctions.Cauchy()
150+
k = CovarianceFunctions.EQ()
151+
g = CovarianceFunctions.GradientKernel(k)
152+
153+
F = BarnesHutFactorization(g, x)
154+
@test F isa BarnesHutFactorization
155+
@test eltype(F) <: IsotropicGradientKernelElement
156+
@test size(F) == (n, n)
157+
@test size(F[1, 1]) == (d, d)
158+
159+
a = [randn(d) for _ in 1:n]
160+
b = [zeros(d) for _ in 1:n]
161+
# F*a
162+
G = gramian(g, x)
163+
b_truth = deepcopy(b)
164+
# @time b_truth = G.A * a
165+
# @time b_truth = G.A * a
166+
mul!(b_truth, G.A, a)
167+
norm_b = sqrt(sum(sum.(abs2, b_truth)))
168+
169+
α, β = 1, 0
170+
θ = 0
171+
taylor!(b, F, a, α, β, θ)
172+
err = sqrt(sum(sum.(abs2, b - b_truth)))
173+
rel_err = err / norm_b
174+
@test rel_err < 1e-10
175+
176+
θ = 1/10
177+
taylor!(b, F, a, α, β, θ)
178+
err = sqrt(sum(sum.(abs2, b - b_truth)))
179+
rel_err = err / norm_b
180+
@test rel_err < 1e-3
181+
end # testset matrix valued bh
148182

149183
end # testset
150184

0 commit comments

Comments
 (0)