Skip to content

Commit a9edfbd

Browse files
replaced Woodbury with specialized GradientKernelElement, decreasing memory allocations drastically
1 parent c420c1c commit a9edfbd

12 files changed

+244
-130
lines changed

Manifest.toml

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ version = "5.0.7"
3838

3939
[[deps.ArrayLayouts]]
4040
deps = ["FillArrays", "LinearAlgebra", "SparseArrays"]
41-
git-tree-sha1 = "8b921542ad44cba67f1487e2226446597e0a90af"
41+
git-tree-sha1 = "c23473c60476e62579c077534b9643ec400f792b"
4242
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
43-
version = "0.8.5"
43+
version = "0.8.6"
4444

4545
[[deps.Artifacts]]
4646
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
@@ -176,9 +176,9 @@ version = "1.0.3"
176176

177177
[[deps.DiffRules]]
178178
deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"]
179-
git-tree-sha1 = "dd933c4ef7b4c270aacd4eb88fa64c147492acf0"
179+
git-tree-sha1 = "28d605d9a0ac17118fe2c5e9ce0fbb76c3ceb120"
180180
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
181-
version = "1.10.0"
181+
version = "1.11.0"
182182

183183
[[deps.Distances]]
184184
deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"]
@@ -237,9 +237,9 @@ version = "0.1.1"
237237

238238
[[deps.ForwardDiff]]
239239
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
240-
git-tree-sha1 = "40d1546a45abd63490569695a86a2d93c2021e54"
240+
git-tree-sha1 = "34e6147e7686a101c245f12dba43b743c7afda96"
241241
uuid = "f6369f11-7733-5829-9624-2563aa707210"
242-
version = "0.10.26"
242+
version = "0.10.27"
243243

244244
[[deps.FunctionWrappers]]
245245
git-tree-sha1 = "241552bc2209f0fa068b6415b1942cc0aa486bcc"
@@ -396,9 +396,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
396396

397397
[[deps.LogExpFunctions]]
398398
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
399-
git-tree-sha1 = "a970d55c2ad8084ca317a4658ba6ce99b7523571"
399+
git-tree-sha1 = "44a7b7bb7dd1afe12bac119df6a7e540fa2c96bc"
400400
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
401-
version = "0.3.12"
401+
version = "0.3.13"
402402

403403
[[deps.Logging]]
404404
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -410,9 +410,9 @@ version = "0.4.10"
410410

411411
[[deps.MLUtils]]
412412
deps = ["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase"]
413-
git-tree-sha1 = "32eeb46fa393ae36a4127c9442ade478c8d01117"
413+
git-tree-sha1 = "202617a5a49a8b5f3b4abf96621f2519b1592c74"
414414
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
415-
version = "0.2.3"
415+
version = "0.2.4"
416416

417417
[[deps.MPC_jll]]
418418
deps = ["Artifacts", "GMP_jll", "JLLWrappers", "Libdl", "MPFR_jll", "Pkg"]
@@ -533,9 +533,9 @@ version = "1.4.1"
533533

534534
[[deps.Parsers]]
535535
deps = ["Dates"]
536-
git-tree-sha1 = "3b429f37de37f1fc603cc1de4a799dc7fbe4c0b6"
536+
git-tree-sha1 = "1285416549ccfcdf0c50d4997a94331e88d68413"
537537
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
538-
version = "2.3.0"
538+
version = "2.3.1"
539539

540540
[[deps.Pkg]]
541541
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
@@ -658,9 +658,9 @@ version = "0.1.14"
658658

659659
[[deps.Static]]
660660
deps = ["IfElse"]
661-
git-tree-sha1 = "2114b1d8517764a8c4625a2e97f40640c7a301a7"
661+
git-tree-sha1 = "b1f1f60bf4f25d8b374480fb78c7b9785edf95fd"
662662
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
663-
version = "0.6.1"
663+
version = "0.6.2"
664664

665665
[[deps.StaticArrays]]
666666
deps = ["LinearAlgebra", "Random", "Statistics"]

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ BesselK = "0.3"
3030
BlockFactorizations = "1.2.1"
3131
DiffResults = "1.0"
3232
FillArrays = "0.12, 0.13"
33+
Flux = "0.13"
3334
ForwardDiff = "0.10"
35+
Functors = "0.2"
3436
IterativeSolvers = "0.9"
3537
KroneckerProducts = "1.0"
3638
LazyArrays = "0.22"

src/barneshut.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function BarnesHutFactorization(k, x, y = x, D = nothing; θ::Real = 1/4, leafsi
3030
# w = zeros(length(m))
3131
# i = zeros(Bool, m)
3232
# WT, BT = typeof(w), typeof(i)
33-
T = gramian_eltype(k, xs, ys)
33+
T = gramian_eltype(k, xs[1], ys[1])
3434
BarnesHutFactorization{T, KT, XT, YT, TT, DT, RT}(k, xs, ys, Tree, D, θ) #, w, i)
3535
end
3636
function BarnesHutFactorization(G::Gramian, θ::Real = 1/2; leafsize::Int = BARNES_HUT_DEFAULT_LEAFSIZE)
@@ -49,7 +49,7 @@ function LinearAlgebra.mul!(y::AbstractVector, F::BarnesHutFactorization, x::Abs
4949
taylor!(y, F, x, α, β)
5050
end
5151
end
52-
function Base.:*(F::BarnesHutFactorization, x::AbstractVector)
52+
function Base.:*(F::BarnesHutFactorization{<:Number}, x::AbstractVector{<:Number})
5353
T = promote_type(eltype(F), eltype(x))
5454
y = zeros(T, size(F, 1))
5555
mul!(y, F, x)
@@ -148,45 +148,45 @@ end
148148

149149
############################# centers of mass ##################################
150150
# this is a weighted sum, could be generalized to incorporate node_sums
151-
function compute_centers_of_mass(x::AbstractVector, w::AbstractVector, T::BallTree)
151+
function compute_centers_of_mass(w::AbstractVector, x::AbstractVector, T::BallTree)
152152
D = eltype(x) <: StaticVector ? length(eltype(x)) : length(x[1]) # if x is static vector
153153
com = [zero(MVector{D, Float64}) for _ in 1:length(T.hyper_spheres)]
154-
compute_centers_of_mass!(com, x, w, T)
154+
compute_centers_of_mass!(com, w, x, T)
155155
end
156156

157157
function compute_centers_of_mass(F::BarnesHutFactorization, w::AbstractVector)
158-
compute_centers_of_mass(F.y, w, F.Tree)
158+
compute_centers_of_mass(w, F.y, F.Tree)
159159
end
160160

161-
function compute_centers_of_mass!(com::AbstractVector, x::AbstractVector, w::AbstractVector, T::BallTree)
161+
function compute_centers_of_mass!(com::AbstractVector, w::AbstractVector, x::AbstractVector, T::BallTree)
162162
abs_w = abs.(w)
163-
weighted_node_sums!(com, x, abs_w, T)
163+
weighted_node_sums!(com, abs_w, x, T)
164164
sum_w = node_sums(abs_w, T)
165165
ε = eps(eltype(w)) # ensuring division by zero is not a problem
166166
@. com ./= sum_w + ε
167167
end
168168

169-
node_sums(x::AbstractVector, T::BallTree) = weighted_node_sums(x, Ones(length(x)), T)
169+
node_sums(x::AbstractVector, T::BallTree) = weighted_node_sums(Ones(length(x)), x, T)
170170
function node_sums!(sums, x::AbstractVector, T::BallTree)
171-
weighted_node_sums!(sums, x, Ones(length(x)), T)
171+
weighted_node_sums!(sums, Ones(length(x)), x, T)
172172
end
173173

174-
function weighted_node_sums(x::AbstractVector, w::AbstractVector, T::BallTree, index::Int = 1)
174+
function weighted_node_sums(w::AbstractVector, x::AbstractVector, T::BallTree, index::Int = 1)
175175
length(x) == 0 && return zero(eltype(x))
176-
sums = zeros(typeof(w[1]'x[1]), length(T.hyper_spheres))
177-
weighted_node_sums!(sums, x, w, T)
176+
sums = fill(zero(w[1]'x[1]), length(T.hyper_spheres))
177+
weighted_node_sums!(sums, w, x, T)
178178
end
179179

180180
# NOTE: x should be either a vector of numbers or a vector of static arrays
181-
function weighted_node_sums!(sums::AbstractVector, x::AbstractVector,
182-
w::AbstractVector{<:Number}, T::BallTree, index::Int = 1)
181+
function weighted_node_sums!(sums::AbstractVector, w::AbstractVector,
182+
x::AbstractVector, T::BallTree, index::Int = 1)
183183
if isleaf(T.tree_data.n_internal_nodes, index)
184184
i = get_leaf_range(T.tree_data, index)
185185
wi, xi = @views w[T.indices[i]], x[T.indices[i]]
186186
sums[index] = wi'xi
187187
else
188-
task = @spawn weighted_node_sums!(sums, x, w, T, getleft(index))
189-
weighted_node_sums!(sums, x, w, T, getright(index))
188+
task = @spawn weighted_node_sums!(sums, w, x, T, getleft(index))
189+
weighted_node_sums!(sums, w, x, T, getright(index))
190190
wait(task)
191191
sums[index] = sums[getleft(index)] + sums[getright(index)]
192192
end

0 commit comments

Comments
 (0)