@@ -30,7 +30,7 @@ function BarnesHutFactorization(k, x, y = x, D = nothing; θ::Real = 1/4, leafsi
30
30
# w = zeros(length(m))
31
31
# i = zeros(Bool, m)
32
32
# WT, BT = typeof(w), typeof(i)
33
- T = gramian_eltype (k, xs, ys)
33
+ T = gramian_eltype (k, xs[ 1 ] , ys[ 1 ] )
34
34
BarnesHutFactorization {T, KT, XT, YT, TT, DT, RT} (k, xs, ys, Tree, D, θ) # , w, i)
35
35
end
36
36
function BarnesHutFactorization (G:: Gramian , θ:: Real = 1 / 2 ; leafsize:: Int = BARNES_HUT_DEFAULT_LEAFSIZE)
@@ -49,7 +49,7 @@ function LinearAlgebra.mul!(y::AbstractVector, F::BarnesHutFactorization, x::Abs
49
49
taylor! (y, F, x, α, β)
50
50
end
51
51
end
52
- function Base.:* (F:: BarnesHutFactorization , x:: AbstractVector )
52
+ function Base.:* (F:: BarnesHutFactorization{<:Number} , x:: AbstractVector{<:Number} )
53
53
T = promote_type (eltype (F), eltype (x))
54
54
y = zeros (T, size (F, 1 ))
55
55
mul! (y, F, x)
@@ -148,45 +148,45 @@ end
148
148
149
149
# ############################ centers of mass ##################################
150
150
# this is a weighted sum, could be generalized to incorporate node_sums
151
- function compute_centers_of_mass (x :: AbstractVector , w :: AbstractVector , T:: BallTree )
151
+ function compute_centers_of_mass (w :: AbstractVector , x :: AbstractVector , T:: BallTree )
152
152
D = eltype (x) <: StaticVector ? length (eltype (x)) : length (x[1 ]) # if x is static vector
153
153
com = [zero (MVector{D, Float64}) for _ in 1 : length (T. hyper_spheres)]
154
- compute_centers_of_mass! (com, x, w , T)
154
+ compute_centers_of_mass! (com, w, x , T)
155
155
end
156
156
157
157
function compute_centers_of_mass (F:: BarnesHutFactorization , w:: AbstractVector )
158
- compute_centers_of_mass (F. y, w , F. Tree)
158
+ compute_centers_of_mass (w, F. y, F. Tree)
159
159
end
160
160
161
- function compute_centers_of_mass! (com:: AbstractVector , x :: AbstractVector , w :: AbstractVector , T:: BallTree )
161
+ function compute_centers_of_mass! (com:: AbstractVector , w :: AbstractVector , x :: AbstractVector , T:: BallTree )
162
162
abs_w = abs .(w)
163
- weighted_node_sums! (com, x, abs_w , T)
163
+ weighted_node_sums! (com, abs_w, x , T)
164
164
sum_w = node_sums (abs_w, T)
165
165
ε = eps (eltype (w)) # ensuring division by zero it not a problem
166
166
@. com ./= sum_w + ε
167
167
end
168
168
169
- node_sums (x:: AbstractVector , T:: BallTree ) = weighted_node_sums (x, Ones (length (x)), T)
169
+ node_sums (x:: AbstractVector , T:: BallTree ) = weighted_node_sums (Ones (length (x)), x , T)
170
170
function node_sums! (sums, x:: AbstractVector , T:: BallTree )
171
- weighted_node_sums! (sums, x, Ones (length (x)), T)
171
+ weighted_node_sums! (sums, Ones (length (x)), x , T)
172
172
end
173
173
174
- function weighted_node_sums (x :: AbstractVector , w :: AbstractVector , T:: BallTree , index:: Int = 1 )
174
+ function weighted_node_sums (w :: AbstractVector , x :: AbstractVector , T:: BallTree , index:: Int = 1 )
175
175
length (x) == 0 && return zero (eltype (x))
176
- sums = zeros ( typeof (w[1 ]' x[1 ]), length (T. hyper_spheres))
177
- weighted_node_sums! (sums, x, w , T)
176
+ sums = fill ( zero (w[1 ]' x[1 ]), length (T. hyper_spheres))
177
+ weighted_node_sums! (sums, w, x , T)
178
178
end
179
179
180
180
# NOTE: x should either be vector of numbers or vector of static arrays
181
- function weighted_node_sums! (sums:: AbstractVector , x :: AbstractVector ,
182
- w :: AbstractVector{<:Number} , T:: BallTree , index:: Int = 1 )
181
+ function weighted_node_sums! (sums:: AbstractVector , w :: AbstractVector ,
182
+ x :: AbstractVector , T:: BallTree , index:: Int = 1 )
183
183
if isleaf (T. tree_data. n_internal_nodes, index)
184
184
i = get_leaf_range (T. tree_data, index)
185
185
wi, xi = @views w[T. indices[i]], x[T. indices[i]]
186
186
sums[index] = wi' xi
187
187
else
188
- task = @spawn weighted_node_sums! (sums, x, w , T, getleft (index))
189
- weighted_node_sums! (sums, x, w , T, getright (index))
188
+ task = @spawn weighted_node_sums! (sums, w, x , T, getleft (index))
189
+ weighted_node_sums! (sums, w, x , T, getright (index))
190
190
wait (task)
191
191
sums[index] = sums[getleft (index)] + sums[getright (index)]
192
192
end
0 commit comments