Fast Taylor decomposition for gradient kernels #14

Open · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion src/CovarianceFunctions.jl
@@ -60,7 +60,7 @@ include("taylor.jl")
include("gradient.jl")
include("gradient_algebra.jl")
include("hessian.jl")
# include("taylor_gradient.jl") # fast MVM algorithm for isotropic GradientKernel Gramians
include("taylor_gradient.jl") # fast MVM algorithm for isotropic GradientKernel Gramians
include("separable.jl")

end # CovarianceFunctions
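
Re-enabling this include activates the fast matrix-vector-multiply path for isotropic GradientKernel Gramians. A minimal usage sketch of that path, mirroring the test added at the bottom of this PR (BarnesHutFactorization, GradientKernel, taylor!, and EQ are the names used in this diff; exact signatures may differ):

using LinearAlgebra
using CovarianceFunctions
using CovarianceFunctions: BarnesHutFactorization, GradientKernel, taylor!

d, n = 2, 1024
x = randn(d, n)                      # n points in d dimensions
g = GradientKernel(CovarianceFunctions.EQ())
F = BarnesHutFactorization(g, x)     # ball-tree factorization of the block Gramian

w = [randn(d) for _ in 1:n]          # block weight vector
b = [zeros(d) for _ in 1:n]          # block target vector
taylor!(b, F, w, 1, 0, 1/10)         # b ≈ K * w, far field compressed with θ = 1/10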
32 changes: 16 additions & 16 deletions src/barneshut.jl
@@ -30,7 +30,7 @@ function BarnesHutFactorization(k, x, y = x, D = nothing; θ::Real = 1/4, leafsi
# w = zeros(length(m))
# i = zeros(Bool, m)
# WT, BT = typeof(w), typeof(i)
-T = gramian_eltype(k, xs, ys)
+T = gramian_eltype(k, xs[1], ys[1])
BarnesHutFactorization{T, KT, XT, YT, TT, DT, RT}(k, xs, ys, Tree, D, θ) #, w, i)
end
function BarnesHutFactorization(G::Gramian, θ::Real = 1/2; leafsize::Int = BARNES_HUT_DEFAULT_LEAFSIZE)
@@ -49,7 +49,7 @@ function LinearAlgebra.mul!(y::AbstractVector, F::BarnesHutFactorization, x::Abs
taylor!(y, F, x, α, β)
end
end
-function Base.:*(F::BarnesHutFactorization, x::AbstractVector)
+function Base.:*(F::BarnesHutFactorization{<:Number}, x::AbstractVector{<:Number})
T = promote_type(eltype(F), eltype(x))
y = zeros(T, size(F, 1))
mul!(y, F, x)
@@ -148,45 +148,45 @@ end

############################# centers of mass ##################################
# this is a weighted sum, could be generalized to incorporate node_sums
-function compute_centers_of_mass(x::AbstractVector, w::AbstractVector, T::BallTree)
+function compute_centers_of_mass(w::AbstractVector, x::AbstractVector, T::BallTree)
D = eltype(x) <: StaticVector ? length(eltype(x)) : length(x[1]) # if x is static vector
com = [zero(MVector{D, Float64}) for _ in 1:length(T.hyper_spheres)]
-compute_centers_of_mass!(com, x, w, T)
+compute_centers_of_mass!(com, w, x, T)
end

function compute_centers_of_mass(F::BarnesHutFactorization, w::AbstractVector)
-compute_centers_of_mass(F.y, w, F.Tree)
+compute_centers_of_mass(w, F.y, F.Tree)
end

-function compute_centers_of_mass!(com::AbstractVector, x::AbstractVector, w::AbstractVector, T::BallTree)
+function compute_centers_of_mass!(com::AbstractVector, w::AbstractVector, x::AbstractVector, T::BallTree)
abs_w = abs.(w)
-weighted_node_sums!(com, x, abs_w, T)
+weighted_node_sums!(com, abs_w, x, T)
sum_w = node_sums(abs_w, T)
ε = eps(eltype(w)) # ensuring division by zero is not a problem
@. com ./= sum_w + ε
end

-node_sums(x::AbstractVector, T::BallTree) = weighted_node_sums(x, Ones(length(x)), T)
+node_sums(x::AbstractVector, T::BallTree) = weighted_node_sums(Ones(length(x)), x, T)
function node_sums!(sums, x::AbstractVector, T::BallTree)
-weighted_node_sums!(sums, x, Ones(length(x)), T)
+weighted_node_sums!(sums, Ones(length(x)), x, T)
end

-function weighted_node_sums(x::AbstractVector, w::AbstractVector, T::BallTree, index::Int = 1)
+function weighted_node_sums(w::AbstractVector, x::AbstractVector, T::BallTree, index::Int = 1)
length(x) == 0 && return zero(eltype(x))
-sums = zeros(typeof(w[1]'x[1]), length(T.hyper_spheres))
-weighted_node_sums!(sums, x, w, T)
+sums = fill(zero(w[1]'x[1]), length(T.hyper_spheres))
+weighted_node_sums!(sums, w, x, T)
end

# NOTE: x should either be vector of numbers or vector of static arrays
-function weighted_node_sums!(sums::AbstractVector, x::AbstractVector,
-w::AbstractVector{<:Number}, T::BallTree, index::Int = 1)
+function weighted_node_sums!(sums::AbstractVector, w::AbstractVector,
+x::AbstractVector, T::BallTree, index::Int = 1)
if isleaf(T.tree_data.n_internal_nodes, index)
i = get_leaf_range(T.tree_data, index)
wi, xi = @views w[T.indices[i]], x[T.indices[i]]
sums[index] = wi'xi
else
-task = @spawn weighted_node_sums!(sums, x, w, T, getleft(index))
-weighted_node_sums!(sums, x, w, T, getright(index))
+task = @spawn weighted_node_sums!(sums, w, x, T, getleft(index))
+weighted_node_sums!(sums, w, x, T, getright(index))
wait(task)
sums[index] = sums[getleft(index)] + sums[getright(index)]
end
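
The argument swap above (weights first, data second) reflects the leaf-level reduction wi'xi: the weights must sit on the adjoint side once w or x hold vectors rather than numbers. A small illustration of the convention, using StaticArrays element types as the NOTE above suggests (a sketch, not part of this diff):

using LinearAlgebra, StaticArrays

# scalar weights, vector-valued data: w'x is the weighted sum over points
w = randn(3)
x = [SVector{2}(randn(2)) for _ in 1:3]
s = w'x                            # SVector{2}: Σⱼ wⱼ xⱼ

# vector-valued weights: passing adjoints makes the same leaf reduction
# produce the outer-product moments Σⱼ wⱼ xⱼ' used by taylor_gradient.jl
wv = [SVector{2}(randn(2)) for _ in 1:3]
S = adjoint.(wv)' * adjoint.(x)    # 2×2 static matrix: Σⱼ wvⱼ xⱼ'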
130 changes: 130 additions & 0 deletions src/taylor_gradient.jl
@@ -0,0 +1,130 @@
############################### matrix valued version ##########################
# function BarnesHutFactorization(k::GradientKernel, x, y = x, D = nothing;
# θ::Real = 1/4, leafsize::Int = BARNES_HUT_DEFAULT_LEAFSIZE)
# xs = vector_of_static_vectors(x)
# ys = x === y ? xs : vector_of_static_vectors(y)
# Tree = BallTree(ys, leafsize = leafsize)
# m = length(y)
# XT, YT, KT, TT, DT, RT = typeof.((xs, ys, k, Tree, D, θ))
# # w = zeros(length(m))
# # i = zeros(Bool, m)
# # WT, BT = typeof(w), typeof(i)
# T = gramian_eltype(k, xs, ys)
# F = BarnesHutFactorization{T, KT, XT, YT, TT, DT, RT}(k, xs, ys, Tree, D, θ) #, w, i)
# end

function LinearAlgebra.mul!(y::AbstractVector{<:Real}, F::BarnesHutFactorization{<:Any, <:GradientKernel},
x::AbstractVector{<:Real}, α::Real = 1, β::Real = 0)
d = length(F.x[1])
X, Y = reshape(x, d, :), reshape(y, d, :) # converting vector of reals to vector of vectors
xx, yy = [c for c in eachcol(X)], [c for c in eachcol(Y)]
mul!(yy, F, xx, α, β)
return y
end
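
This method exposes the block Gramian as an ordinary (d·n) × (d·n) operator on flat vectors; per reshape(x, d, :), the layout is point-major, i.e. the d components of each point are contiguous. A hedged sketch of the flat calling convention, reusing F, d, n from the sketch further up:

wflat = randn(d * n)                 # components of point j at (j-1)*d+1 : j*d
bflat = zeros(d * n)
mul!(bflat, F, wflat)                # same result as the block version, flattened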

function LinearAlgebra.mul!(y::AbstractVector{<:AbstractVector}, F::BarnesHutFactorization{<:Any, <:GradientKernel},
x::AbstractVector{<:AbstractVector}, α::Real = 1, β::Real = 0)
taylor!(y, F, x, α, β)
end

# function Base.:*(F::BarnesHutFactorization{<:Number}, x::AbstractVector{<:Number})
# T = promote_type(eltype(F), eltype(x))
# y = zeros(T, size(F, 1))
# mul!(y, F, x)
# end

function taylor!(b::AbstractVector{<:AbstractVector}, F::BarnesHutFactorization{<:Any, <:GradientKernel},
w::AbstractVector{<:AbstractVector}, α::Number = 1, β::Number = 0, θ::Real = F.θ)
size(F, 2) == length(w) || throw(DimensionMismatch("length of w does not match second dimension of F: $(length(w)) ≠ $(size(F, 2))"))
# eltype(b) == promote_type(eltype(F), eltype(w)) ||
# throw(TypeError("eltype of target vector b not equal to eltype of matrix-vector product: $(eltype(b)) and $(promote_type(eltype(F), eltype(w)))"))
f_orders(r²) = derivatives(F.k.k, r², 3)
sums_w = node_sums(w, F.Tree) # IDEA: could pre-allocate, though likely not worth it: several orders of magnitude cheaper than the multiply
sums_w_r = weighted_node_sums(adjoint.(w), adjoint.(F.y), F.Tree) # need sum of outer products of F.y and w
centers = get_hyper_centers(F)
@. sums_w_r -= sums_w * adjoint(centers) # need to center the moments
Gijs = [F[1, 1] for _ in 1:Base.Threads.nthreads()]
for i in eachindex(F.x) # exactly 4 * length(y) allocations?
if β == 0
@. b[i] = 0 # this avoids trouble if b is initialized with NaNs, e.g. through "similar"
else
@. b[i] *= β
end
Gij = Gijs[Base.Threads.threadid()]
taylor_recursion!(b[i], Gij, 1, F.k, f_orders, F.x[i], F.y,
w, sums_w, sums_w_r, θ, F.Tree, centers, α) # x[i] creates an allocation
end
if !isnothing(F.D) # if there is a diagonal correction, need to add it
mul!(b, F.D, w, α, 1)
end
return b
end
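
The in-place update of sums_w_r above ("need to center the moments") is just the identity

$$\sum_{j \in \mathcal{N}} w_j (y_j - c)^\top \;=\; \sum_{j \in \mathcal{N}} w_j\, y_j^\top \;-\; \Big(\sum_{j \in \mathcal{N}} w_j\Big)\, c^\top,$$

so the centered outer-product moment of each tree node is obtained from the uncentered moments (sums_w_r) and the weight sums (sums_w) without a second tree traversal.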

# Barnes-Hut recursion for matrix-valued kernels; could be merged with the scalar version
# bi is the target vector corresponding to input point xi
# Gij is temporary storage for the evaluation of k(xi, y[j]), important to avoid
# allocations when the kernel element is matrix-valued
function taylor_recursion!(bi::AbstractVector, Gij,
index, k::GradientKernel, f_orders,
xi, y::AbstractVector,
w::AbstractVector{<:AbstractVector},
sums_w::AbstractVector{<:AbstractVector},
sums_w_r::AbstractVector{<:AbstractMatrix},
θ::Real, T::BallTree, centers, α::Number)
h = T.hyper_spheres[index]
if isleaf(T.tree_data.n_internal_nodes, index) # do direct computation
for i in get_leaf_range(T.tree_data, index)
j = T.indices[i]
# @time Gij = evaluate_block!(Gij, k, xi, y[j]) # k(xi, y[j])
Gij = IsotropicGradientKernelElement{eltype(bi)}(k.k, xi, y[j])
wj = w[j]
mul!(bi, Gij, wj, α, 1)
end
return bi

elseif h.r < θ * euclidean(xi, centers[index]) # compress
S_index = sums_w_r[index]
ri = difference(xi, centers[index])
sum_abs2_ri = sum(abs2, ri)
# NOTE: this line is the only one that still allocates ~688 bytes
f0, f1, f2, f3 = f_orders(sum_abs2_ri) # zeroth- through third-order derivative evaluations of k at the squared distance

# Gij = evaluate_block!(Gij, k, xi, centers[index]) # k(xi, centers_of_mass[index])
Gij = IsotropicGradientKernelElement{eltype(bi)}(k.k, xi, centers[index])

# zeroth order
mul!(bi, Gij, sums_w[index], α, 1) # bi .+= α * k(xi, centers_of_mass[index]) * sums_w[index]
# first order
# this block has zero allocations
mul!(bi, 2*f3*dot(ri, S_index, ri) + f2 * tr(S_index), ri, 4α, 1)
mul!(bi, S_index, ri, 4α*f2, 1)
mul!(bi, S_index', ri, 4α*f2, 1)
return bi
else # recurse NOTE: parallelizing here is not as efficient as parallelizing over target points
taylor_recursion!(bi, Gij, getleft(index), k, f_orders, xi, y, w, sums_w, sums_w_r, θ, T, centers, α)
taylor_recursion!(bi, Gij, getright(index), k, f_orders, xi, y, w, sums_w, sums_w_r, θ, T, centers, α)
end
end
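
Schematically, the compress branch replaces a well-separated node $\mathcal{N}$ (one with $h.r < \theta\,\lVert x_i - c\rVert$) by a truncated Taylor expansion about its center $c$:

$$\sum_{j \in \mathcal{N}} K(x_i, y_j)\, w_j \;\approx\; K(x_i, c) \sum_{j \in \mathcal{N}} w_j \;+\; \text{first-order correction}(S), \qquad S = \sum_{j \in \mathcal{N}} w_j (y_j - c)^\top,$$

where the zeroth-order term reuses the same lazy IsotropicGradientKernelElement as the leaf computation, and the correction is assembled from the derivatives f2 and f3 of the scalar kernel at $\lVert x_i - c\rVert^2$, as in the three mul! calls above; the leaf branch remains the exact fallback for nearby nodes.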

# function node_mapreduce(x::AbstractVector, w::AbstractVector, T::BallTree, index::Int = 1)
# length(x) == 0 && return zero(eltype(x))
# sums = fill(w[1]*x[1]', length(T.hyper_spheres))
# node_outer_products!(sums, x, w, T)
# end
#
# # NOTE: x should either be vector of numbers or vector of static arrays
# function node_mapreduce!(sums::AbstractVector{<:AbstractMatrix}, x::AbstractVector,
# w::AbstractVector{<:Number}, T::BallTree, index::Int = 1)
# if isleaf(T.tree_data.n_internal_nodes, index)
# i = get_leaf_range(T.tree_data, index)
# wi, xi = @views w[T.indices[i]], x[T.indices[i]]
# sums[index] = wi * xi'
# # adjoint.(w)' * adjoint.(x)
# else
# task = @spawn weighted_node_sums!(sums, x, w, T, getleft(index))
# weighted_node_sums!(sums, x, w, T, getright(index))
# wait(task)
# sums[index] = sums[getleft(index)] + sums[getright(index)]
# end
# return sums
# end
72 changes: 53 additions & 19 deletions test/barneshut.jl
@@ -1,12 +1,14 @@
module TestBarnesHut
using LinearAlgebra
using WoodburyFactorizations
using CovarianceFunctions
using CovarianceFunctions: BarnesHutFactorization, barneshut!, vector_of_static_vectors,
-node_sums, euclidean, GradientKernel, taylor!
+node_sums, euclidean, GradientKernel, taylor!, IsotropicGradientKernelElement
using NearestNeighbors
using NearestNeighbors: isleaf, getleft, getright, get_leaf_range
using Test

+# like Barnes-Hut, but sets the far-field contribution to zero
function barneshut_no_far_field!(b::AbstractVector, F::BarnesHutFactorization, w::AbstractVector,
α::Number = 1, β::Number = 0, θ::Real = F.θ)
D = length(eltype(F.x))
@@ -95,10 +97,9 @@ verbose = false
nexp = 16
err = zeros(nexp)
err_no_split = zeros(nexp)
-err_hyp = zeros(nexp)
-err_hyp_no_split = zeros(nexp)
err_nff = zeros(nexp)
err_taylor = zeros(nexp)
+err_taylor_hyp = zeros(nexp)
theta_array = range(1e-1, 1, length = nexp)
for (i, θ) in enumerate(theta_array)
barneshut!(b_bh, F, w, 1, 0, θ, split = true)
@@ -110,8 +111,11 @@ verbose = false
barneshut_no_far_field!(b_bh, F, w, 1, 0, θ) # compare against pseudo barnes hut where far field = 0
err_nff[i] = norm(b - b_bh)

-taylor!(b_bh, F, w, 1, 0, θ) # compare against pseudo barnes hut where far field = 0
+taylor!(b_bh, F, w, 1, 0, θ, use_com = true)
err_taylor[i] = norm(b - b_bh)

+taylor!(b_bh, F, w, 1, 0, θ, use_com = false)
+err_taylor_hyp[i] = norm(b - b_bh)
end

rel_err = err / norm(b)
@@ -122,29 +126,59 @@
# using Plots
# plotly()
#
-# rel_err_nff = err_nff / norm(b)
-# rel_err_no_split = err_no_split / norm(b)
-# rel_err_taylor = err_taylor / norm(b)
-#
+# norm_b = norm(b)
+# rel_err_nff = err_nff / norm_b
+# rel_err_no_split = err_no_split / norm_b
+# rel_err_taylor = err_taylor / norm_b
+# rel_err_taylor_hyp = err_taylor_hyp / norm_b
# #
# plot(theta_array, rel_err, yscale = :log10, label = "barneshut", ylabel = "relative error", xlabel = "θ")
# plot!(theta_array, rel_err_no_split, yscale = :log10, label = "no split")
# plot!(theta_array, rel_err_nff, yscale = :log10, label = "sparse")
# plot!(theta_array, rel_err_taylor, yscale = :log10, label = "taylor")
+# plot!(theta_array, rel_err_taylor_hyp, yscale = :log10, label = "taylor hyper-sphere centers")
# gui()

end # testset weight vectors


# @testset "gradient kernels" begin
# k = CovarianceFunctions.Cauchy()
# g = CovarianceFunctions.GradientKernel(k)
#
# F = BarnesHutFactorization(g, x)
# @test F isa BarnesHutFactorization
# @test eltype(F) <: Diagonal
# @test size(F) == (n, n)
# @test size(F[1, 1]) == (d_out, d_out)
# end # testset matrix valued bh
@testset "gradient kernels" begin

n = 1024
d = 2
x = randn(d, n)
# k = CovarianceFunctions.Cauchy()
k = CovarianceFunctions.EQ()
g = CovarianceFunctions.GradientKernel(k)

F = BarnesHutFactorization(g, x)
@test F isa BarnesHutFactorization
@test eltype(F) <: IsotropicGradientKernelElement
@test size(F) == (n, n)
@test size(F[1, 1]) == (d, d)

a = [randn(d) for _ in 1:n]
b = [zeros(d) for _ in 1:n]
# F*a
G = gramian(g, x)
b_truth = deepcopy(b)
# @time b_truth = G.A * a
# @time b_truth = G.A * a
mul!(b_truth, G.A, a)
norm_b = sqrt(sum(sum.(abs2, b_truth)))

α, β = 1, 0
θ = 0
taylor!(b, F, a, α, β, θ)
err = sqrt(sum(sum.(abs2, b - b_truth)))
rel_err = err / norm_b
@test rel_err < 1e-10

θ = 1/10
taylor!(b, F, a, α, β, θ)
err = sqrt(sum(sum.(abs2, b - b_truth)))
rel_err = err / norm_b
@test rel_err < 1e-3
end # testset matrix valued bh

end # testset
